2017-05-03 50 views
6

阅读https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py的功能average_gradients以下注释提供:Note that this function provides a synchronization point across all towers.是功能average_gradients阻塞调用,并是什么意思synchronization pointTensorflow CIFAR同步点

我认为这是一个阻塞呼叫,因为为了计算梯度的平均值,每个梯度都必须单独计算?但是,等待所有单个渐变计算的阻止代码在哪里?

回答

6

average_gradients本身不是阻塞功能。它可能是张量流操作的另一个函数,这仍然是一个同步点。是什么让它阻塞是因为它使用参数tower_grads,这取决于在前面的for循环中创建的所有图形。

基本上这里发生的是创建训练图。首先,在for循环for i in xrange(FLAGS.num_gpus)中创建了几个图表“线程”。每个看起来像这样:

计算损失 - >计算梯度 - >附加到tower_grads

每个的那些曲线图“线程”被分配给一个不同的GPU通过with tf.device('/gpu:%d' % i)并且每一个可以运行彼此独立的(并且稍后将并行运行)。现在下一次使用tower_grads时没有设备规范,它会在主设备上创建一个图继续,将所有这些单独的图“线程”绑定到一个单独的图上。在运行average_gradients函数中的图之前,Tensorflow将确保作为创建tower_grads的一部分的每个图形“线程”都已完成。因此稍后调用sess.run([train_op, loss])时,这将成为图的同步点。