2016-07-26 73 views
0

的形状,我测试下面的代码脚本关于打印张

import tensorflow as tf 

a, b, c = 2, 3, 4 
x = tf.Variable(tf.random_normal([a, b, c], mean=0.0, stddev=1.0, dtype=tf.float32)) 
s = tf.shape(x) 
print(s) 

init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 
print(sess.run(s)) 

运行代码得到如下结果

Tensor("Shape:0", shape=(3,), dtype=int32) 
[2 3 4] 

貌似只有第二打印给人的可读格式。第一次印刷真的在做什么或如何理解第一次输出?

回答

1

s = tf.shape(x)的调用定义了一个符号(但非常简单)的TensorFlow计算,只有在您致电sess.run(s)时才执行。

当您执行print(s) Python将打印TensorFlow知道张量s的所有内容,而无需实际评估它。由于它是tf.shape() op的输出,因此TensorFlow知道它具有类型tf.int32,并且TensorFlow也可以推断出它是长度为3的向量(因为x静态地被称为来自变量定义的3-D张量)。

注意,在很多情况下,你就可以获得更多的形状信息而无需打印特定张的静态形状,使用Variable.get_shape()方法(类似于其表弟Tensor.get_shape())调用张量:

# Print the statically known shape of `x`. 
print(x.get_shape()) 
# ==> "(2, 3, 4)"