2017-07-06 80 views
2

我正在使用TensorFlow来实现RNN。我创建了重复单元这样的:dynamic_rnn的输出形状with time_major = True

gru_cell = tf.contrib.rnn.GRUCell(16) 
zero_state = gru_cell.zero_state(1, tf.float32) 
initial_state = tf.placeholder(tf.float32, zero_state.get_shape()) 
out_tensor, final_state = tf.nn.dynamic_rnn(
    gru_cell, 
    parent_tensor, 
    initial_state=initial_state, 
    time_major=False) 
print(out_tensor.get_shape()) 

它报告的输出形状(1, ?, 16),正如我所期望的。第二个维度是?,因为max_time未知。

现在我切换到time_major=True。基于文档,我期望只交换前两个轴,所以输出形状应该是(?, 1, 16)。但事实并非如此。相反,它是(1, 1, 16)。这是怎么回事? max_time还不得而知,那为什么要把它硬编码为1呢?

回答

相关问题