1

我最近将我的tesnorflow从Rev8升级到Rev12。在Rev8中,rnn_cell.LSTMCell中的默认“state_is_tuple”标志被设置为False,所以我用列表初始化了我的LSTM Cell,请参阅下面的代码。如何使用元组初始化LSTMCell

#model definition 
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim) 
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state) 


#init_state place holder and feed_dict 
def add_placeholders(self): 
    self.init_state = tf.placeholder("float", [None, self.cell_size]) 

def get_feed_dict(self, data, label): 
    feed_dict = {self.input_data: data, 
      self.input_label: reg_label, 
      self.init_state: np.zeros((self.config.batch_size, self.cell_size))} 
    return feed_dict 

在Rev12,默认的“state_is_tuple”标志被设置为True,以使我的旧代码的工作,我不得不把标志明确转向为False。不过,现在我从tensorflow的警告说:“使用级联状态较慢,很快就会被弃用 使用state_is_tuple =真”

我试图初始化一个LSTM细胞元组通过改变占位符定义self.init_state以下内容:

self.init_state = tf.placeholder("float", (None, self.cell_size)) 

,但现在我得到了一个错误信息说:

“‘张量’的对象是不是可迭代”

有谁知道如何使这项工作?

+1

不幸的是,元组是一个复杂的结构。你是否必须*明确地使'init_state'成为一个占位符?使用'cell.zero_state'代替它会好得多。别担心,您可以跨'run_dict'传递状态 – martianwars

回答

1

现在使用cell.zero_state为LSTM提供“零状态”要简单得多。您不需要明确地将初始状态定义为占位符。将其定义为张量,并根据需要进行填充。这是如何工作的,

lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim) 
self.initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32) 
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state) 

如果要喂一些其他的价值作为初始状态,假设next_state = states[-1]例如,在您的会话计算,并通过它在feed_dict像 -

feed_dict[self.initial_state] = next_state 

在你的问题中,lstm_cell.zero_state()就足够了。


不相关,但请记住,您可以在Feed字典中传递张量和占位符!这就是self.initial_state在上面的例子中的工作原理。查看PTB Tutorial的实例。