我的一位同事指出,当您需要屏蔽Keras中的非RNN输入时,使用sample_weight
代替屏蔽层的选项非常酷。Keras:为非RNN屏蔽零填充输入
就我而言,我在输入中有62列,第63位是响应。前62列中97%以上的非零值包含在前30列中。我试图让这个工作,所以我想重量最后32列在训练中0,本质上是创造一个'穷人的面具'。
这是一个8级分类任务,使用MLP。响应变量已使用Keras中的to_categorical()
函数进行了转换。
这里的实现:
model = Sequential()
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu'))
model.add(Dense(8, init='uniform', activation='sigmoid'))
hist = model.fit(X, y,
validation_data=(X_test, ytest),
nb_epoch=epochs_,
batch_size=batch_size_,
callbacks=callbacks_list,
sample_weight = np.array([X.shape[1]-32, 30]))
我得到这个错误:
in standardize_weights
assert y.shape[:sample_weight.ndim] == sample_weight.shape
如何解决我的sample_weight
为 '面具' 输入的前32列?