2017-05-05 74 views
1

让我们考虑一个numpy的矩阵,o越来越多指数值从张量一次,在tensorflow

如果我们要使用numpy的使用下面的功能:

o[np.arange(x), column_array] 

我能得到来自一个numpy数组的多个索引。

我试图用tensorflow做同样的事情,但它不像我所做的那样工作。当o是张量流张量时;

o[tf.range(0, x, 1), column_array] 

我得到以下错误:

TypeError: can only concatenate list (not "int") to list 

我能做些什么?

回答

1

您可能希望看到tf.gather_ndhttps://www.tensorflow.org/api_docs/python/tf/gather_nd

import tensorflow as tf 
import numpy as np 

tensor = tf.placeholder(tf.float32, [2,2]) 
indices = tf.placeholder(tf.int32, [2,2]) 
selected = tf.gather_nd(tensor, indices=indices) 

with tf.Session() as session: 
    data = np.array([[0.1,0.2],[0.3,0.4]]) 
    idx = np.array([[0,0],[1,1]]) 
    result = session.run(selected, feed_dict={indices:idx, tensor:data}) 
    print(result) 

,其结果将是[ 0.1 0.40000001]

3

你可以试试tf.gather_nd(),为How to select rows from a 3-D Tensor in TensorFlow?这篇文章建议。 以下是从矩阵o获取多个索引的示例。

o = tf.constant([[1, 2, 3, 4], 
       [5, 6, 7, 8], 
       [9, 10, 11, 12], 
       [13, 14, 15, 16]]) 
# [row_index, column_index], I don’t figure out how to 
# combine row vector and column vector into this form. 
indices = tf.constant([[0, 0], [0, 1], [2, 1], [2, 3]]) 

result = tf.gather_nd(o, indices) 

with tf.Session() as sess: 
    print(sess.run(result)) #[ 1 2 10 12]