2017-05-30 115 views
1

我在Keras中实现了我自己的复发层,并且在step函数内部我希望能够访问所有时间步骤中的隐藏状态,而不仅仅是最后一个状态如默认,以便我可以做一些事情,比如及时向后添加跳过连接。返回Keras中RNN中所有时间步骤的所有状态

我试图修改tensorflow后端K.rnn内的_step以返回到目前为止所有隐藏状态。我最初的想法是简单地将每个隐藏状态存储到TensorArray中,然后将所有这些状态都传递给step_function(即我的层中的step函数)。我现在的修改功能;下面,写每个隐藏状态转变为TensorArray states_ta_t

def _step(time, output_ta_t, states_ta_t, *states): 
      current_input = input_ta.read(time) 
      # Here I'd like to return all states up to current time 
      # and pass to step_function, instead of just the last 
      states = [states_ta_t.read(time)] 
      output, new_states = step_function(current_input, 
               tuple(states) + 
               tuple(constants)) 
      for state, new_state in zip(states, new_states): 
       new_state.set_shape(state.get_shape()) 
      states_ta_t = states_ta_t.write(time+1, new_states[0]) # record states 
      output_ta_t = output_ta_t.write(time, output) 
      return (time + 1, output_ta_t, states_ta_t) + tuple(new_states) 

这个版本只返回一个状态,就像当初的实施,并可以作为一个正常的RNN。我如何获取迄今为止所有的状态,存储在数组中,并传递给step_function?感觉这应该是非常简单的,但我对TensorArrays不是很熟练......

(注意:在展开的版本中这比在符号上更容易,但不幸的是我会用完使用我的实验展开的版本)

回答

2

记忆 - 编辑 -

我发现,我误解你的问题,我非常抱歉为...

总之,试试这个:

states = states_ta_t.stack()[:time] 

下面是一些说明:你确实已经将所有这些状态存储在states_ta_t中,但你只能通过最后一个到step_function

你已经在你的代码做的是:

# Param 'time' refers to 'current time step' 
states = [states_ta_t.read(time)] 

这意味着,你从states_ta_t读取“当前”的状态,换句话说,最后的状态。

如果你想做一些切片,也许stack功能将有所帮助。例如:

states = states_ta_t.stack()[:time] 

但我不知道这是否是一个正确的实施,因为我不熟悉TensorArray要么...

希望它能帮助!如果没有,如果你愿意留下评论并与我讨论,这是我的荣幸!

+1

欢迎来到SO。请阅读这个[如何回答] (http://stackoverflow.com/help/how-to-answer)提供高质量的答案。 – thewaywewere

+1

感谢您的咨询!我会读它,并尽力提高我的答案:) – Carefree0910

+1

非常感谢@ Carefree0910。这回答了我的问题,我不知道我可以用这种方式对它们进行切片:-)我最终意识到,我可能会以这种方式使用太多的内存,通过在state_ta_t中一次保持所有状态。所以我最终创建了两个TensorArrays,一个用于当前时间步和一个上一个时间步,用“clear_after_read = True”,这样我只能访问一个额外的状态,但只保留两个状态随时在内存中。 – jodles

相关问题