2016-01-20 81 views
3

我想分割两台机器的最小化功能。在一台机器上,我打电话为“compute_gradients”,在另一台机器上我通过网络发送带有渐变的“apply_gradients”。问题是调用apply_gradients(...)。run(feed_dict)似乎不管用我做什么。我试图取代张量梯度的apply_gradients的插入占位符,TensorFlow远程apply_gradients

variables = [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2] 
    loss = -tf.reduce_sum(y_ * tf.log(y_conv)) 
    optimizer = tf.train.AdamOptimizer(1e-4) 
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
    compute_gradients = optimizer.compute_gradients(loss, variables) 

    placeholder_gradients = [] 
    for grad_var in compute_gradients: 
    placeholder_gradients.append((tf.placeholder('float', shape=grad_var[1].get_shape()) ,grad_var[1])) 
    apply_gradients = optimizer.apply_gradients(placeholder_gradients) 

后来当我收到的梯度我打电话

feed_dict = {} 
    for i, grad_var in enumerate(compute_gradients): 
     feed_dict[placeholder_gradients[i][0]] = tf.convert_to_tensor(gradients[i]) 
    apply_gradients.run(feed_dict=feed_dict) 

然而,当我这样做,我得到

ValueError:用序列设置数组元素。

这只是我尝试过的最新事情,我也尝试过没有占位符的相同解决方案,并且等待创建apply_gradients操作,直到我收到梯度,这会导致不匹配的图形错误。

任何帮助我应该去哪个方向?

+0

请注意,使用占位符(代表'apply_gradients'中的渐变张量)是不必要的,因为您可以为任何变量(张量)提供numpy值。您可以使用'compute_gradients'返回的原始梯度张量。 – Falcon

回答

6

假设每个gradients[i]是,你已经取出的使用一些外的带机构的NumPy的阵列,所述的解决方法是简单地以除去tf.convert_to_tensor()调用构建feed_dict时:

feed_dict = {} 
for i, grad_var in enumerate(compute_gradients): 
    feed_dict[placeholder_gradients[i][0]] = gradients[i] 
apply_gradients.run(feed_dict=feed_dict) 

在每个值feed_dict应该是一个NumPy数组(或其他一些可以普通转换为NumPy数组的对象)。特别是,tf.Tensor对于feed_dict不是有效值。

+0

OMG就是这样,非常感谢你,一直在为此工作一周。 – syzygy

+0

是否记录在某处?除了代码本身之外,我很难找到有关feed_dict工作原理的信息。也许我可以发布一个问题,以便吐出有用的错误消息。 – syzygy

+1

主要文档位于[Session.run()](https://www.tensorflow.org/versions/master/api_docs/python/client.html#Session)的文档中,它描述了可能在'feed_dict'。我想我们也[最近改进](https://github.com/tensorflow/tensorflow/blob/4c3b060b9b063a247172c939e0e1314577b581e3/tensorflow/python/client/session.py#L370)在这种情况下引发的异常,但是可能还没有在当前版本中! – mrry