2017-10-10 110 views
0

看我的代码中写道:图间复制的分布式tensorflow?

import tensorflow as tf 


tf.flags.DEFINE_string('job_name', 'ps', 'worker or ps') 
tf.flags.DEFINE_integer('task_index', 0, 'task id') 
FLAGS = tf.flags.FLAGS 

host = '127.0.0.1:' 
cluster = {"ps": [host+'2222'], 
      "worker": [host+'2223', host+'2224']} 
clusterspec = tf.train.ClusterSpec(cluster) 

server = tf.train.Server(cluster, 
         job_name=FLAGS.job_name, 
         task_index=FLAGS.task_index) 

def print_fn(): 
    print('job_name: %s, task_index: %d' % (FLAGS.job_name, FLAGS.task_index)) 


if FLAGS.job_name == 'ps': 
    server.join() 
elif FLAGS.job_name == 'worker': 
    with tf.device(tf.train.replica_device_setter(
     worker_device="/job:worker/task:%d" % FLAGS.task_index, 
     cluster=cluster)): 

     a = tf.Variable(tf.zeros([]), name='a') 
     b = tf.Variable(tf.zeros([]), name='b') 
     op = tf.add(a, b) 
     print(a.device) 
     print(b.device) 
     print(op.device) 
     print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) 
print_fn() 

当我运行在cmdpython distributed_replicas.py --job_name=worker --task_index=0,但在此之前运行python distributed_replicas.py --job_name=ps --task_index=0,该程序也适用。 a.deviceb.device都是/job:ps/task:0,但ps server不启动,变量ab如何存储在ps server?并且tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)也包含变量ab,这意味着ab创建在/job:worker/task:0,虽然他们的设备是/job:ps/task:0,所以怎么了?在哪里ab产生的?

回答

0

这是预期的行为。至此,在您的代码中,您只创建了一个图表,该图表不需要/关心作业即可启动并运行。

创建一个会话(或会话的任何变化)后,您会遇到的问题。

点击此处了解详情: https://www.tensorflow.org/extend/architecture

+0

它编码创建分布式'master'的和分区图?这是否意味着每个“工作服务器”上的图形都将被分区?这是否意味着每个'worker server'都会让接收节点和所有'worker server'都会从'ps server'接收相同的参数? – gaussclb

+0

一般而言,万事达(哪个客户端被连接在创建会话时向服务器)将创建的分区的图形,并将其发送到其他服务器(服务器由一个新的会话或tf.train.Server创建任一)。 – xldrx

+0

当使用PS架构中,所有的工人有欧普的Recv在前进的开始通过从一个或多个的PS读取数据和发送op将向后传递期间更新发送回PS。 – xldrx