2017-08-03 117 views

I am using tf.while_loop to concatenate tensors dynamically, but tf.while_loop only keeps the result of the last iteration.

Code

embeds_raw = tf.constant(np.array([ 
    [1, 1], 
    [1, 1], 
    [2, 2], 
    [3, 3], 
    [3, 3], 
    [3, 3] 
], dtype='float32')) 
embeds = tf.Variable(initial_value=embeds_raw) 
container_variable = tf.zeros([512], dtype=tf.int32, name='container_variable') 
sen_len = tf.placeholder('int32', shape=[None], name='sen_len') 
# max_l = tf.reduce_max(sen_len) 
current_size = tf.shape(sen_len)[0] 
padded_sen_len = tf.pad(sen_len, [[0, 512 - current_size]], 'CONSTANT') 
added_container_variable = tf.add(container_variable, padded_sen_len) 
u1 = tf.TensorArray(dtype=tf.float32, size=512, clear_after_read=False) 
u1 = u1.split(embeds, added_container_variable) 
res = tf.split(embeds, added_container_variable) 

i = tf.constant(0, shape=(), dtype='int32', name='i') 
x = tf.Variable(tf.constant(0, shape=[2, 2], dtype=tf.float32), dtype=tf.float32) 

def condition(_i, _x): 
    return tf.less(_i, current_size) 

def body(_i, _x): 
    return _i + 1, tf.concat([x, u1.read(_i)], axis=0) 

idx, x = tf.while_loop(
    condition, 
    body, 
    [i, x], 
    shape_invariants=[tf.TensorShape([]), tf.TensorShape([None, 2])], 
) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    sents = sess.run(x, feed_dict={sen_len: [2, 1, 3]}) 
    print(sents) 
    print(len(res)) 

What happens is that it concatenates on every iteration but then discards the modification. In other words, each new iteration does not use the result of the previous one.

This is the output I get:

[[ 0. 0.] 
[ 0. 0.] 
[ 3. 3.] 
[ 3. 3.] 
[ 3. 3.]] 

whereas my desired output is:

[[ 0. 0.] 
[ 0. 0.] 
[ 1. 1.] 
[ 1. 1.] 
[ 2. 2.] 
[ 3. 3.] 
[ 3. 3.] 
[ 3. 3.]] 

Answer


That happens because of this line in `body`:

return _i + 1, tf.concat([x, u1.read(_i)], axis=0) 

You should change it to use the loop variable `_x` instead of the Python variable `x` captured from the enclosing scope. `x` always refers to the initial value, so every iteration concatenates onto the same starting tensor and only the last chunk survives:

return _i + 1, tf.concat([_x, u1.read(_i)], axis=0) 
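To see why the closure bug produces exactly the truncated output above, here is a minimal sketch in plain NumPy (no TensorFlow) of how `tf.while_loop` threads its loop variables: each iteration's `body` receives the previous iteration's outputs, not the Python variables it closes over. The `while_loop` helper and the `chunks` data below are illustrative stand-ins, not the real TensorFlow implementation.

```python
import numpy as np

def while_loop(cond, body, loop_vars):
    # Simplified model of tf.while_loop: feed each iteration's
    # outputs back in as the next iteration's loop variables.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

# Stand-ins for the three split segments of `embeds` (lengths 2, 1, 3).
chunks = [np.ones((2, 2)), np.full((1, 2), 2.0), np.full((3, 2), 3.0)]
x0 = np.zeros((2, 2))  # initial value, like the question's `x`

def buggy_body(i, _x):
    # Concatenates onto the captured initial value x0 every time,
    # so each iteration overwrites the previous result.
    return i + 1, np.concatenate([x0, chunks[i]], axis=0)

def fixed_body(i, _x):
    # Concatenates onto the loop-carried value _x, accumulating.
    return i + 1, np.concatenate([_x, chunks[i]], axis=0)

cond = lambda i, _x: i < len(chunks)
_, bad = while_loop(cond, buggy_body, (0, x0))
_, good = while_loop(cond, fixed_body, (0, x0))
print(bad.shape)   # (5, 2): zeros plus only the last chunk
print(good.shape)  # (8, 2): zeros plus all three chunks
```

The buggy version reproduces the 5-row output from the question (two zero rows plus the last 3-row chunk), while the fixed version yields the desired 8 rows.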