2016-06-21 54 views
0

为了举例,考虑计算张量流中的内积。我试图在TensorFlow的图表中用不同的方式来引用图表中的事物,当用一个使用feed的会话对它进行评估时。请看下面的代码:你可以给TensorFlow提供什么数据类型作为关键字?

import numpy as np 
import tensorflow as tf 

M = 4 
D = 2 
D1 = 3 
x = tf.placeholder(tf.float32, shape=[M, D], name='data_x') # M x D 
W = tf.Variable(tf.truncated_normal([D,D1], mean=0.0, stddev=0.1)) # (D x D1) 
b = tf.Variable(tf.constant(0.1, shape=[D1])) # (D1 x 1) 
inner_product = tf.matmul(x,W) + b # M x D1 
with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    x_val = np.random.rand(M,D) 
    #print type(x.name) 
    #print x.name 
    name = x.name 
    ans = sess.run(inner_product, feed_dict={name: x_val}) 
    ans = sess.run(inner_product, feed_dict={x.name: x_val}) 
    ans = sess.run(inner_product, feed_dict={x: x_val}) 
    name_str = unicode('data_x', "utf-8") 
    ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work 
    ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work 
    ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work 
    print ans 

了以下工作:

ans = sess.run(inner_product, feed_dict={name: x_val}) 
ans = sess.run(inner_product, feed_dict={x.name: x_val}) 
ans = sess.run(inner_product, feed_dict={x: x_val}) 

,但最后三个:

name_str = unicode('data_x', "utf-8") 
ans = sess.run(inner_product, feed_dict={"data_x": x_val}) #doesn't work 
ans = sess.run(inner_product, feed_dict={'data_x': x_val}) #doesn't work 
ans = sess.run(inner_product, feed_dict={name_str: x_val}) #doesn't work 

没有。我检查了为什么键入x.name,但即使我将它转换为python类型的解释器时,它仍然无法正常工作。 I documentation似乎认为键必须是张量。然而,它接受x.name,而它不是张量(它的一个<type 'unicode'>),是否有人知道发生了什么?


我可以粘贴文件说,它需要一个张量:

可选feed_dict参数允许呼叫者覆盖在图形张量 值。在feed_dict每个键可以是 以下类型之一:

如果键是一张量,该值可以是一个Python标量,字符串, 列表或numpy的ndarray可以转换到相同的D型为 张量。此外,如果密钥是占位符,则将检查值的形状是否与占位符兼容。如果 键是SparseTensor,则该值应该是SparseTensorValue。 feed_dict中的每个 值必须可转换为相应键的dtype 的numpy数组。

回答

0

TensorFlow主要期望tf.Tensor对象作为Feed词典中的键。如果它等于会话图中某些tf.Tensor.name属性,它也将接受一个字符串(可能是bytesunicode)。

在你的例子中,x.name的作品,因为xtf.Tensor,你正在评估它的.name属性。 "data_val"不起作用,因为它是tf.Operation(即x.op)的名称,而不是tf.Tensor的名称,该名称是tf.Operation的输出。如果您打印x.name,您会看到它的值为"data_val:0",这意味着“tf.Operation的第0个输出被称为"data_val"

相关问题