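I recently found that LayerNormBasicLSTMCell is a version of LSTM with layer normalization and dropout built in, so I replaced the LSTMCell in my original code with LayerNormBasicLSTMCell. This change not only reduced the test accuracy from about 96% to about 92%, it also made training take much longer (about 33 hours, compared with about 6 hours originally). All parameters are the same: number of epochs (10), number of stacked layers (3), hidden vector size (250), dropout keep probability (0.5), ... The hardware is also the same. Why is LayerNormBasicLSTMCell slower and less accurate than LSTMCell?
My question is: what am I doing wrong here?
My original model (using LSTMCell):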
# Batch normalization of the raw input
tf_b_VCCs_AMs_BN1 = tf.layers.batch_normalization(
    tf_b_VCCs_AMs,  # the input vector, size [#batches, #time_steps, 2]
    axis=-1,  # axis that should be normalized
    training=Flg_training,  # Flg_training = True during training, and False during test
    trainable=True,
    name="Inputs_BN"
)

# Bidirectional dynamic stacked LSTM

##### The part I changed in the new model (start) #####
dropcells = []
for iiLyr in range(3):
    cell_iiLyr = tf.nn.rnn_cell.LSTMCell(num_units=250, state_is_tuple=True)
    dropcells.append(tf.nn.rnn_cell.DropoutWrapper(cell=cell_iiLyr, output_keep_prob=0.5))
##### The part I changed in the new model (end) #####

MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=dropcells, state_is_tuple=True)

outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=MultiLyr_cell,
    cell_bw=MultiLyr_cell,
    dtype=tf.float32,
    sequence_length=tf_b_lens,  # the actual lengths of the input sequences (tf_b_VCCs_AMs_BN1)
    inputs=tf_b_VCCs_AMs_BN1,
    scope="BiLSTM"
)
My new model (using LayerNormBasicLSTMCell):
...
dropcells = []
for iiLyr in range(3):
    cell_iiLyr = tf.contrib.rnn.LayerNormBasicLSTMCell(
        num_units=250,
        forget_bias=1.0,
        activation=tf.tanh,
        layer_norm=True,
        norm_gain=1.0,
        norm_shift=0.0,
        dropout_keep_prob=0.5
    )
    dropcells.append(cell_iiLyr)
...
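For reference, here is a minimal sketch of what layer normalization computes for one example at one time step, assuming the standard definition (normalize over the feature axis, then apply a gain and a shift). It is plain Python rather than TensorFlow, and the `layer_norm` helper, its `epsilon`, and the sample values are hypothetical, chosen only for illustration. As far as I can tell from the contrib source, the cell applies this kind of normalization to each gate pre-activation and to the new cell state at every time step, which is extra work that LSTMCell does not do.

```python
import math


def layer_norm(values, gain=1.0, shift=0.0, epsilon=1e-12):
    """Normalize one feature vector to zero mean and unit variance, then scale and shift.

    gain and shift mirror norm_gain=1.0 and norm_shift=0.0 from the model above;
    epsilon is an assumed small constant that avoids division by zero.
    """
    n = len(values)
    mean = sum(values) / n
    variance = sum((v - mean) ** 2 for v in values) / n
    inv_std = 1.0 / math.sqrt(variance + epsilon)
    return [gain * (v - mean) * inv_std + shift for v in values]


# Hypothetical pre-activations of one gate for a single example at a single time step.
gate_preactivation = [0.5, -1.0, 2.0, 3.5]
normalized = layer_norm(gate_preactivation)

# Statistics of the normalized vector (mean close to 0, variance close to 1).
norm_mean = sum(normalized) / len(normalized)
norm_var = sum((v - norm_mean) ** 2 for v in normalized) / len(normalized)
```

The statistics are computed per example over the feature axis, so nothing depends on the batch and there are no moving averages to maintain. Separately, the two models may not apply dropout in the same place: 'dropout_keep_prob' in LayerNormBasicLSTMCell is documented as a recurrent dropout probability, while 'output_keep_prob' in DropoutWrapper drops the cell outputs, so keeping the value 0.5 does not by itself make the two setups equivalent.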
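One thought: could [this](https://stackoverflow.com/questions/43234667/tf-layers-batch-normalization-large-test-error) be the problem? It seems that the mean and variance are not updated automatically in 'tf.layers.batch_normalization'. I wonder whether 'tf.contrib.rnn.LayerNormBasicLSTMCell' suffers from the same problem.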
@FariborzGhavamian, I used the second approach for the normalization function (i.e. 'update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)' and 'with tf.control_dependencies(update_ops):' ...).
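To illustrate why those update ops matter for batch normalization, here is a minimal plain-Python sketch, not TensorFlow code. It assumes the 'tf.layers.batch_normalization' defaults as I understand them (momentum 0.99, epsilon 0.001, moving mean initialized to 0 and moving variance to 1); the helper names and the toy batches are hypothetical.

```python
import math


def batch_stats(batch):
    """Return the mean and (biased) variance of one batch of scalars."""
    n = len(batch)
    mean = sum(batch) / n
    variance = sum((x - mean) ** 2 for x in batch) / n
    return mean, variance


def update_moving(moving, batch_value, momentum=0.99):
    """Exponential moving average step, the job the update ops perform."""
    return momentum * moving + (1.0 - momentum) * batch_value


def normalize(x, mean, variance, epsilon=1e-3):
    """Inference-time normalization using the stored moving statistics."""
    return (x - mean) / math.sqrt(variance + epsilon)


# Hypothetical training data: 1000 identical batches with mean 5.0 and variance 0.5.
batches = [[4.0, 6.0, 5.0, 5.0]] * 1000

# Case 1: the update ops run at every training step.
moving_mean, moving_var = 0.0, 1.0
for batch in batches:
    mean, var = batch_stats(batch)
    moving_mean = update_moving(moving_mean, mean)
    moving_var = update_moving(moving_var, var)

# Case 2: the update ops never run, so the statistics keep their initial values.
stale_mean, stale_var = 0.0, 1.0

# Normalize a typical test input with each set of statistics.
with_updates = normalize(5.0, moving_mean, moving_var)
without_updates = normalize(5.0, stale_mean, stale_var)
```

With the updates, a typical input is mapped close to 0, as it was during training. Without them, the same input stays near 5, so the network sees test inputs on a very different scale. Layer normalization keeps no moving statistics of this kind, so this particular failure mode should not apply to the cell itself.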
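Regarding the training time: I found this on the TensorFlow website: https://www.tensorflow.org/performance/performance_guide#common_fused_ops. You can turn on a parameter called 'fused' (for batch normalization, according to that guide) and get a 12%-30% speedup.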