2017-02-28

I am using TensorFlow 0.12.1 and following this doc. How do I get this very simple distributed training example to work?

What I want to do is add 1 to count in each worker.

My goal is to print a result greater than 1, but I only ever get 1.

import tensorflow as tf 

FLAGS = tf.app.flags.FLAGS 
tf.app.flags.DEFINE_string('job_name', '', '') 
tf.app.flags.DEFINE_string('ps_hosts', '','') 
tf.app.flags.DEFINE_string('worker_hosts', '','') 
tf.app.flags.DEFINE_integer('task_index', 0, '') 

ps_hosts = FLAGS.ps_hosts.split(',') 
worker_hosts = FLAGS.worker_hosts.split(',') 
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,'worker': worker_hosts}) 
server = tf.train.Server(
        {'ps': ps_hosts,'worker': worker_hosts}, 
        job_name=FLAGS.job_name, 
        task_index=FLAGS.task_index) 

if FLAGS.job_name == 'ps': 
    server.join() 

with tf.device(tf.train.replica_device_setter(
       worker_device="/job:worker/task:%d" % FLAGS.task_index, 
       cluster=cluster_spec)): 
    count = tf.Variable(0) 
    count = tf.add(count,tf.constant(1)) 
    init = tf.global_variables_initializer() 

sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                         logdir="./checkpoint/",
                         init_op=init,
                         summary_op=None,
                         saver=None,
                         global_step=None,
                         save_model_secs=60)

with sv.managed_session(server.target) as sess:
    sess.run(init)
    step = 1
    while step <= 999999999:
        result = sess.run(count)
        if step % 10000 == 0:
            print(result)
        if result >= 2:
            print("!!!!!!!!")
        step += 1
    print("Finished!")

sv.stop() 

Answer


The problem is actually independent of distributed execution and stems from these two lines:

count = tf.Variable(0)
count = tf.add(count, tf.constant(1))

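The tf.add() op is purely functional: each time it runs, it produces a new tensor holding its output instead of modifying its inputs. Because the Python name count is rebound to that output, every sess.run(count) reads the variable's unchanged value 0 and returns 0 + 1 = 1. If you want to increment the value, and have the increment be visible across workers, you must use the tf.Variable.assign_add() method instead, as follows: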

count = tf.Variable(0)
# assign_add() returns an op that adds 1 to the variable in place and yields the new value.
increment_count = count.assign_add(1)

Then call sess.run(increment_count) inside your training loop to increment the value of the count variable.
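To make the difference concrete, here is a minimal pure-Python toy model, not TensorFlow code. ToyVariable, functional_add and assign_add are hypothetical names standing in for a variable held on the ps task, for tf.add(), and for Variable.assign_add() respectively. The two "workers" run sequentially, so concurrency is deliberately ignored.

```python
class ToyVariable:
    """Hypothetical stand-in for a tf.Variable that lives on the ps task."""

    def __init__(self, value):
        self.value = value


def functional_add(var, delta):
    # Mimics tf.add(count, 1): reads the variable, returns a new value,
    # and never writes anything back to var.
    return var.value + delta


def assign_add(var, delta):
    # Mimics count.assign_add(1): updates the shared value in place
    # and returns the updated value.
    var.value += delta
    return var.value


shared = ToyVariable(0)

# A sequential schedule standing in for two workers, two steps each.
# With the functional op, every run sees 0 + 1.
add_results = [functional_add(shared, 1) for _worker in range(2) for _step in range(2)]
print(add_results, shared.value)  # [1, 1, 1, 1] 0

# Same schedule with assign_add: every run updates the one shared value.
assign_results = [assign_add(shared, 1) for _worker in range(2) for _step in range(2)]
print(assign_results, shared.value)  # [1, 2, 3, 4] 4
```

In the real cluster, replica_device_setter places the variable on the ps task, which is why increments made through assign_add() by one worker are seen by the others.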
