2017-02-17 65 views
2

我的一位同事指出,当您需要屏蔽Keras中的非RNN输入时,使用sample_weight代替屏蔽层的选项非常酷。Keras:为非RNN屏蔽零填充输入

就我而言,我在输入中有62列,第63位是响应。前62列中97%以上的非零值包含在前30列中。我试图让这个工作,所以我想重量最后32列在训练中0,本质上是创造一个'穷人的面具'。

这是一个8级分类任务,使用MLP。响应变量已使用Keras中的to_categorical()函数进行了转换。

这里的实现:

model = Sequential() 
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu')) 
model.add(Dense(8, init='uniform', activation='sigmoid')) 
hist = model.fit(X, y, 
       validation_data=(X_test, ytest), 
       nb_epoch=epochs_, 
       batch_size=batch_size_, 
       callbacks=callbacks_list, 
       sample_weight = np.array([X.shape[1]-32, 30])) 

我得到这个错误:

in standardize_weights 
assert y.shape[:sample_weight.ndim] == sample_weight.shape 

如何解决我的sample_weight为 '面具' 输入的前32列?

回答

2

样品重量是不工作这样的:

sample_weight : optional array of the same length as x , containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length) , to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile() . source

换句话说,该设置使上样本训练数据的,而不是在每个样品的特征的不同的权重。这仅用于训练步骤。 我想你应该使用遮罩,如果你不希望图层使用这些功能。或者只是从你的数据集中删除它们?或者,如果它不太复杂,让网络自己学习哪些有用的功能。

这有帮助吗?