所以这是真的很奇怪。我能有以下修改运行代码:
import tensorflow as tf
global_step = tf.train.get_or_create_global_step()
incr_global_step = global_step.assign(global_step + 1)
w = tf.cond(tf.equal(tf.mod(global_step.initialized_value(), 2), 0),
lambda : tf.get_variable('w1', initializer=tf.zeros([], dtype=tf.int32)),
lambda : tf.get_variable('w2', initializer=tf.ones([], dtype=tf.int32)).initialized_value())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(4):
print(sess.run([w, incr_global_step]))
注意,我不得不添加initialized_value
到global_step
和tf.cond
变量'w2'
,但不知何故不到'w1'
(你可以把它也和它将起作用,但如果你不这样做,显然不会提交)。正如文档中提到的那样,这种方法通常不是必需的,它只是给出了一个“视图”,保证在被初始化后使用。为什么tf.cond
要求你使用它,以及为什么这样不一致的方式,我不知道。
除此之外,请注意,您运行代码的方式实际上并不确定。一般你会得到这个:
[1, 1]
[0, 2]
[1, 3]
[0, 4]
但并非总是如此。这是我刚刚得到的一个输出:
[0, 1]
[0, 2]
[1, 3]
[0, 4]
这是因为运行增量和条件的顺序不确定。这是更好地更明确一些相关性,所以如果你想w
后运行的增加,你会怎么做:
import tensorflow as tf
global_step = tf.train.get_or_create_global_step()
incr_global_step = global_step.assign(global_step + 1)
with tf.control_dependencies([incr_global_step]):
w = tf.cond(tf.equal(tf.mod(global_step.initialized_value(), 2), 0),
lambda : tf.get_variable('w1', initializer=tf.zeros([], dtype=tf.int32)).initialized_value(),
lambda : tf.get_variable('w2', initializer=tf.ones([], dtype=tf.int32)).initialized_value())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(4):
print(sess.run([w, incr_global_step]))
这令人惊讶的要求我补充initialized_value
为'w1'
了。这实际上是不一致的。此外,在这种情况下的输出是:
[0, 2]
[1, 3]
[0, 4]
[1, 5]
现在,它使我烦恼,增量的结果从两开始。看起来增量运行一次超过预期。所以我觉得tf.cond
以某种方式强制执行一次额外的第一次运行,这将是其奇怪行为的原因。
如果你想反其道而行之,有w
之前运行的增加,你可以这样做:
import tensorflow as tf
w = tf.cond(tf.equal(tf.mod(global_step.initialized_value(), 2), 0),
lambda : tf.get_variable('w1', initializer=tf.zeros([], dtype=tf.int32)),
lambda : tf.get_variable('w2', initializer=tf.ones([], dtype=tf.int32)).initialized_value())
with tf.control_dependencies([w]):
incr_global_step = global_step.assign(global_step + 1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(4):
print(sess.run([w, incr_global_step]))
是的,我不需要'w1'
的initialized_value
了一次。这产生:
[0, 1]
[1, 2]
[0, 3]
[1, 4]
在这里,增量,我认为,是有道理的。
感谢您的回复,但我觉得global_step默认添加到全局变量集合中。对于我的例子'print(tf.global_variables())'产生'[,,]' –
GeertH
@GeertH Woops你确实是对的,我在我的机器上试过了,但是看到了错误的地方,所以我想''。 train.get_or_create_global_step'没有将新创建的全局步骤添加到全局变量集合(我发现它很奇怪,但无论如何)。我会再看看。 – jdehesa
@GeertH我改变了答案。 – jdehesa