2017-02-12 847 views
3

我正在寻找一种类似于Python的list.index()函数的TensorFlow方法。如何查找TensorFlow中第一个匹配元素的索引

给出一个矩阵和一个值来查找,我想知道在矩阵的每一行中值的第一次出现。

例如,

m is a <batch_size, 100> matrix of integers 
val = 23 

result = [0] * batch_size 
for i, row_elems in enumerate(m): 
    result[i] = row_elems.index(val) 

我不能假设“VAL”只出现在每行中一次,否则我会使用tf.argmax(米== VAL)已经实现它。在我的情况下,重要的是要获得第一个发生'val'的索引,而不是任何。

回答

6

看起来tf.argmax的工作原理类似np.argmax(根据the test ),当有多个最大值出现时,它将返回第一个索引。 你可以使用tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)来得到你想要的。但是,目前tf.argmax的行为在最大值出现多次的情况下未定义。

如果您担心未定义的行为,您可以按照@Igor Tsvetkov的建议在tf.where的返回值上应用tf.argmin。 例如,

# test with tensorflow r1.0 
import tensorflow as tf 

val = 3 
m = tf.placeholder(tf.int32) 
m_feed = [[0 , 0, val, 0, val], 
      [val, 0, val, val, 0], 
      [0 , val, 0, 0, 0]] 

tmp_indices = tf.where(tf.equal(m, val)) 
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0]) 

with tf.Session() as sess: 
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1] 

注意tf.segment_min将提高InvalidArgumentError时,有不包含val一些列。 row_elems.index(val)row_elems不包含val时也会引发异常。

+0

这非常有帮助!如果我们想将val更新为new_val,该怎么办?我在这里问这个问题:https://stackoverflow.com/questions/45684445/tensorflow-update-first-matching-element-in-each-row – reese0106

1

看起来有点丑,但工程(假设mval都是张量):

idx = list() 
for t in tf.unpack(m, axis=0): 
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val)))) 
idx = tf.pack(idx, axis=0) 

编辑: 作为Yaroslav Bulatov提到的,你可以实现与tf.map_fn相同的结果:

def index1d(t): 
    return tf.reduce_min(tf.where(tf.equal(t, val))) 

idx = tf.map_fn(index1d, m, dtype=tf.int64) 
+1

'map_fn'无需解包就可以做到这一点 –

相关问题