2016-11-25 118 views
1

有谁知道如何提取rank 2张量中每行最大的n个最大值?Tensorflow top张量中的n值

例如,如果我想形状[2,4]的张量的顶部2的值与值:

[[40,30,20,10],[10,20,30,40 ]]

所需的条件矩阵会是什么样子: [真,真,假,假],[假,假,真,真]

一旦有了条件矩阵,我可以使用tf.select选择实际值。

谢谢你的协助!

回答

6

你可以做到这一点使用内置tf.nn.top_k功能:

a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]]) 
b = tf.nn.top_k(a, 2) 

print(sess.run(b)) 
TopKV2(values=array([[40, 30], 
    [40, 30]], dtype=int32), indices=array([[0, 1], 
    [3, 2]], dtype=int32)) 

print(sess.run(b).values)) 
array([[40, 30], 
     [40, 30]], dtype=int32) 

要得到布尔True/False值,你可以先得到第k值,然后使用tf.greater_equal

kth = tf.reduce_min(b.values) 
top2 = tf.greater_equal(a, kth) 
print(sess.run(top2)) 
array([[ True, True, False, False], 
     [False, False, True, True]], dtype=bool) 
+0

谢谢寻求帮助。有没有一种简单的方法来使用这些top_k值来获得True和False值的原始大小的张量? –

+0

是的,请参阅编辑答案。 – sygi

+1

谢谢!你的意思是tf.greater_equal(a,kth)? –