2017-02-10 407 views
7

我使用TensorFlow构建深度学习模型。对TensorFlow来说是新的。如何用累积梯度更新模型参数?

由于某种原因,我的模型的批量大小有限,那么这个有限的批量大小将使模型具有很高的方差。

所以,我想用一些技巧来增大批量。我的想法是存储每个小批量的梯度,例如64个小批量,然后将梯度求和在一起,使用这64个小批量训练数据的平均梯度来更新模型的参数。

这意味着对于前63个小批次,不要更新参数,并且在64个小批量后,只更新一次模型参数。

但是,由于TensorFlow是基于图形的,有谁知道如何实现这个想要的功能?

非常感谢。

+0

[sync replicas optimizer](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/sync_replicas_optimizer.py)你在找什么? –

+0

似乎我可以存储所有中间渐变,然后计算渐变的均值,然后更新模型参数。 – weixsong

+0

同步副本优化器似乎适用于多GPU并行训练。我会研究它是否可以利用它。 – weixsong

回答

4

我发现这里的解决方案:https://github.com/tensorflow/tensorflow/issues/3994#event-766328647

opt = tf.train.AdamOptimizer() 
tvs = tf.trainable_variables() 
accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]           
zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars] 
gvs = opt.compute_gradients(rmse, tvs) 
accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(gvs)] 
train_step = opt.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(gvs)]) 

在训练循环:

while True: 
    sess.run(zero_ops) 
    for i in xrange(n_minibatches): 
     sess.run(accum_ops, feed_dict=dict(X: Xs[i], y: ys[i])) 
    sess.run(train_step) 

但这代码似乎不是很干净漂亮,没有人知道如何优化这些代码?

2

我有同样的问题,只是想通了。

首先得到符号梯度,然后将累积梯度定义为tf.Variables。 (看来tf.global_variables_initializer()具有限定grads_accum之前运行的。我有错误,否则,不知道为什么。)

tvars = tf.trainable_variables() 
optimizer = tf.train.GradientDescentOptimizer(lr) 
grads = tf.gradients(cost, tvars) 

# initialize 
tf.local_variables_initializer().run() 
tf.global_variables_initializer().run() 

grads_accum = [tf.Variable(tf.zeros_like(v)) for v in grads] 
update_op = optimizer.apply_gradients(zip(grads_accum, tvars)) 

在训练中可以积累梯度在每批次(保存在gradients_accum),之后更新模型运行64个批次:

feed_dict = dict() 
for i, _grads in enumerate(gradients_accum): 
    feed_dict[grads_accum[i]] = _grads 
sess.run(fetches=[update_op], feed_dict=feed_dict) 

可以参考tensorflow/tensorflow/python/training/optimizer_test.py例如使用,特别是这样的功能:testGradientsAsVariables()

希望它有帮助。

+0

我不认为这个代码与这个问题有关。大家总结的是什么?另外,在你提到的例子中,渐变不会累积;他们计算w.r.t.到两个输入独立。 –