2017-05-25 101 views
2

我正在使用dynamic_rnn和LSTMCell,后者放出一个包含内部状态的LSTMStateTuple。调用这个对象的重塑(我的错误)会导致一个张量,而不会在创建图表时造成任何错误。通过图形输入输入时,我在运行时也没有出现任何错误。在LSTMStateTuple上调用变形将它变成张量

代码:

cell = tf.contrib.rnn.LSTMCell(size, state_is_tuple=True, ...) 
outputs, states = tf.nn.dynamic_rnn(cell, inputs, ...) 
print(states) # state is an LSTMStateTuple 
states = tf.reshape(states, [-1, size]) 
print(states) # state is a tensor of shape [?, size] 

这是一个错误(我问,因为它没有记录任何地方)?什么是重构张量控制?

回答

0

我已经进行了类似的实验,它可能给你一些提示:

>>> s = tf.constant([[0, 0, 0, 1, 1, 1], 
        [2, 2, 2, 3, 3, 3]]) 
>>> t = tf.constant([[4, 4, 4, 5, 5, 5],                
        [6, 6, 6, 7, 7, 7]]) 
>>> g = tf.reshape((s, t), [-1, 3]) # <tf.Tensor 'Reshape_1:0' shape=(8, 3) dtype=int32> 
>>> sess.run(g) 
array([[0, 0, 0], 
     [1, 1, 1], 
     [2, 2, 2], 
     [3, 3, 3], 
     [4, 4, 4], 
     [5, 5, 5], 
     [6, 6, 6], 
     [7, 7, 7]], dtype=int32) 

我们可以看到,它只是串接两个张量在第一维和进行整形。由于LSTMStateTuple就像一个namedtuple,所以它和元组的效果相同,我认为这也是你的情况。

让我们走的更远,

>>> st = tf.contrib.rnn.LSTMStateTuple(s, t) 
>>> gg = tf.reshape(st, [-1, 3]) 
>>> sess.run(gg) 
    array([[0, 0, 0], 
      [1, 1, 1], 
      [2, 2, 2], 
      [3, 3, 3], 
      [4, 4, 4], 
      [5, 5, 5], 
      [6, 6, 6], 
      [7, 7, 7]], dtype=int32) 

我们可以看到,如果我们创建一个LSTMStateTuple,结果证实了我们的假设。