2016-08-03 229 views
0

我已经张量的定义如下:如何在Tensorflow中从张量中获取特定行?

idx = tf.constant([0, 2]) 

现在我想利用temp_var一个子集在那些:

temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]])) 

我也有行索引的阵列,以从张量中获取指标即idx

我知道,要采取单一索引或切片,我们可以做这样的事情

temp_var[single_row_index, :] 

temp_var[start:end, :] 

但如何读取行由idx阵列表示? 类似于temp_var[idx, :]

回答

2

tf.gather() op正好满足您的需求:它从矩阵(或从N维张量中选择一般(N-1)维片)中选择行。以下是它如何在你的情况下工作:

temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])) 
idx = tf.constant([0, 2]) 

rows = tf.gather(temp_var, idx) 

init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 

print(sess.run(rows)) # ==> [[1, 2, 3], [7, 8, 9]]