我最近将我的tesnorflow从Rev8升级到Rev12。在Rev8中,rnn_cell.LSTMCell中的默认“state_is_tuple”标志被设置为False,所以我用列表初始化了我的LSTM Cell,请参阅下面的代码。如何使用元组初始化LSTMCell
#model definition
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)
#init_state place holder and feed_dict
def add_placeholders(self):
self.init_state = tf.placeholder("float", [None, self.cell_size])
def get_feed_dict(self, data, label):
feed_dict = {self.input_data: data,
self.input_label: reg_label,
self.init_state: np.zeros((self.config.batch_size, self.cell_size))}
return feed_dict
在Rev12,默认的“state_is_tuple”标志被设置为True,以使我的旧代码的工作,我不得不把标志明确转向为False。不过,现在我从tensorflow的警告说:“使用级联状态较慢,很快就会被弃用 使用state_is_tuple =真”
我试图初始化一个LSTM细胞元组通过改变占位符定义self.init_state以下内容:
self.init_state = tf.placeholder("float", (None, self.cell_size))
,但现在我得到了一个错误信息说:
“‘张量’的对象是不是可迭代”
有谁知道如何使这项工作?
不幸的是,元组是一个复杂的结构。你是否必须*明确地使'init_state'成为一个占位符?使用'cell.zero_state'代替它会好得多。别担心,您可以跨'run_dict'传递状态 – martianwars