2017-08-08 89 views
-1

给定一个类实例列表,我需要使用tf.tensor将其索引。例如:如何使用TensorFlow张量索引类实例列表

Class Something(): 
    def __init__(self): 
     self.a = 1 
     self.b = 2 

list = [Something() for a in range(0, 10)] 
index_queue = tf.train.range_input_producer(len(list)) 
index = index_queue.dequeue() 
result = list[index] 
tensor = function_that_returns_tensor(result) 
with tf.Session() as sess: 
    sess.run(tensor) 

上面的代码给出以下错误:TypeError: list indices must be integers, not Tensor

并采用tf.gather(list, index)提供了以下错误:

TypeError: Expected binary or unicode string, got <__main__.Something object at 0x7f4529fae2b0> 

任何帮助,将不胜感激。谢谢!

+0

为什么你使用'tf.constant(..)'? 'list [2]'会正常工作... –

+0

我修改了这个问题。所以index是一个tf.tensor,它在执行图时会有一些价值。 –

回答

0

问题出在TensorFlow工作原理的核心机制上。当您调用tf.train.range_input_producer(len(list))tf.constant等TensorFlow方法时,您实际上并不是运行这些操作。您只需将这些操作添加到TensorFlow计算图。然后您必须使用tf.Session实例的run方法来运行这些操作并从中获取结果。 TypeError: list indices must be integers, not Tensor告诉您,您将计算图上的张量引用作为索引传递,而不是运行产生张量的操作返回的结果。请参阅this TensorFlow documentation

+0

非常感谢您的回复。是的,我了解Tensorflow的整体机制。我报告的错误是我在tf.Session中运行这些操作时得到的。我相应地修改了这个问题。 –

+0

@UmarIqbal,在您更新的代码中,您仍然将张量的引用作为索引传递给列表,而不是从运行'tf.Session'返回的内容。在你的代码中,'index'是一个张量的引用,而不是一个整数。要从它得到一个整数,你需要运行'index_value = sess.run(index)'。然后'list [index_value]'将起作用。 – golmschenk

+0

@UmarIqbal,但是请注意,您的代码存在另一个问题,它来自使用队列。如果你做出我说过的改变,你的代码似乎会挂起。这是因为队列需要队列运行器才能工作。关于[这里]的更多信息(https://www.tensorflow.org/programmers_guide/threading_and_queues)。但是你最初使用常量(或者其他非队列生成张量)而不是队列的例子应该可以工作。 – golmschenk