我在Keras中实现了我自己的复发层,并且在step
函数内部我希望能够访问所有时间步骤中的隐藏状态,而不仅仅是最后一个状态如默认,以便我可以做一些事情,比如及时向后添加跳过连接。返回Keras中RNN中所有时间步骤的所有状态
我试图修改tensorflow后端K.rnn
内的_step
以返回到目前为止所有隐藏状态。我最初的想法是简单地将每个隐藏状态存储到TensorArray中,然后将所有这些状态都传递给step_function
(即我的层中的step
函数)。我现在的修改功能;下面,写每个隐藏状态转变为TensorArray states_ta_t
:
def _step(time, output_ta_t, states_ta_t, *states):
current_input = input_ta.read(time)
# Here I'd like to return all states up to current time
# and pass to step_function, instead of just the last
states = [states_ta_t.read(time)]
output, new_states = step_function(current_input,
tuple(states) +
tuple(constants))
for state, new_state in zip(states, new_states):
new_state.set_shape(state.get_shape())
states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t, states_ta_t) + tuple(new_states)
这个版本只返回一个状态,就像当初的实施,并可以作为一个正常的RNN。我如何获取迄今为止所有的状态,存储在数组中,并传递给step_function
?感觉这应该是非常简单的,但我对TensorArrays不是很熟练......
(注意:在展开的版本中这比在符号上更容易,但不幸的是我会用完使用我的实验展开的版本)
欢迎来到SO。请阅读这个[如何回答] (http://stackoverflow.com/help/how-to-answer)提供高质量的答案。 – thewaywewere
感谢您的咨询!我会读它,并尽力提高我的答案:) – Carefree0910
非常感谢@ Carefree0910。这回答了我的问题,我不知道我可以用这种方式对它们进行切片:-)我最终意识到,我可能会以这种方式使用太多的内存,通过在state_ta_t中一次保持所有状态。所以我最终创建了两个TensorArrays,一个用于当前时间步和一个上一个时间步,用“clear_after_read = True”,这样我只能访问一个额外的状态,但只保留两个状态随时在内存中。 – jodles