2017-08-27
import tensorflow as tf
hidden_size = 1 
batch_size = 1 
seq_len = 3 
feature_dim = 1 
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size) 
init_state = tf.placeholder_with_default(
    lstm_cell.zero_state(batch_size=batch_size, dtype=tf.float32), shape = [2, batch_size, hidden_size]) 
l = tf.Variable([seq_len]) 
v = tf.Variable(tf.random_normal(shape=[batch_size, seq_len, feature_dim], mean = 0, stddev = 0.01), name = 'v', trainable=True, dtype=tf.float32) 
otuput, out_state = tf.nn.dynamic_rnn(lstm_cell, v, [seq_len], initial_state= init_state) 
with tf.Session() as ses: 
    ses.run(tf.global_variables_initializer()) 

TypeError from the arguments to tf.nn.dynamic_rnn: I wrote the TensorFlow code above, and when I run it I get this error:

TypeError         Traceback (most recent call last) 
<ipython-input-55-f105d2bb8ade> in <module>() 
     8 l = tf.Variable([seq_len]) 
     9 v = tf.Variable(tf.random_normal(shape=[batch_size, seq_len, feature_dim], mean = 0, stddev = 0.01), name = 'v', trainable=True, dtype=tf.float32) 
---> 10 otuput, out_state = tf.nn.dynamic_rnn(lstm_cell, v, [seq_len], initial_state= init_state) 
    11 with tf.Session() as ses: 
    12  ses.run(tf.global_variables_initializer()) 
...
TypeError: 'Tensor' object is not iterable. 

What is wrong with this code?

Answer


The initial_state argument of tf.nn.dynamic_rnn should be a fully defined tensor, not a placeholder. Replacing the init_state line with this one will fix the error.
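The replacement line itself is not reproduced in the answer. The sketch below assumes it substitutes the cell's own zero_state(...) for the placeholder. The most likely reason this matters: BasicLSTMCell defaults to state_is_tuple=True, so it expects its state as an LSTMStateTuple (c, h) and unpacks it at every step. Wrapping zero_state in tf.placeholder_with_default packs that pair into a single Tensor of shape [2, batch_size, hidden_size], and unpacking a graph-mode Tensor raises "'Tensor' object is not iterable". The code assumes TensorFlow 1.x graph mode, reached through tensorflow.compat.v1 (present in 1.14+ and in TF 2 builds that still ship the legacy RNN cells); the unused variable l is dropped.

```python
# Assumption: TF 1.x graph-mode API via the compat.v1 shim (TF 1.14+, or a TF 2
# build that still includes the legacy tf.nn.rnn_cell classes).
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

hidden_size = 1
batch_size = 1
seq_len = 3
feature_dim = 1

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_size)

# zero_state() returns an LSTMStateTuple(c, h), the structure that
# BasicLSTMCell (state_is_tuple=True by default) unpacks at each time step.
# Passing it directly, instead of through placeholder_with_default, keeps
# that tuple structure intact.
init_state = lstm_cell.zero_state(batch_size=batch_size, dtype=tf.float32)

v = tf.Variable(
    tf.random_normal(shape=[batch_size, seq_len, feature_dim], mean=0.0, stddev=0.01),
    name="v", trainable=True, dtype=tf.float32)

# The third argument of dynamic_rnn is sequence_length; naming it makes that explicit.
output, out_state = tf.nn.dynamic_rnn(
    lstm_cell, v, sequence_length=[seq_len], initial_state=init_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_val, state_val = sess.run([output, out_state])
    print(out_val.shape)      # (1, 3, 1): [batch_size, seq_len, hidden_size]
    print(state_val.c.shape)  # (1, 1): [batch_size, hidden_size]
```

If the initial state must stay feedable, building c and h as two separate tf.placeholder_with_default tensors of shape [batch_size, hidden_size] and wrapping them in tf.nn.rnn_cell.LSTMStateTuple(c, h) should also give the cell the structure it expects; this is an untested alternative, not part of the original answer.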
