2017-10-14 114 views
0

我试图收集特定张量/(向量/矩阵)在角膜张量内的索引。因此,我试图使用tf.gathertf.where来获取在收集功能中使用的索引。张量流条件轴

然而,当测试相等时,tf.where为匹配值提供元素明智的索引。我希望能够找到张量(向量)的索引(行),这些索引与另一个相等。

这对找到张量中与一组感兴趣的热点向量相匹配的单向量向量特别有用。

我有一些代码来说明到目前为止的缺点:

# standard 
import tensorflow as tf 
import numpy as np 
from sklearn.preprocessing import LabelBinarizer 
sess = tf.Session() 

# one-hot vector encoding labels 
l = LabelBinarizer() 
l.fit(['a','b','c']) 

# input tensor 
t = tf.constant(l.transform(['a','a','c','b', 'a'])) 

# find the indices where 'c' is label 
# ***THIS WORKS*** 
np.all(t.eval(session = sess) == l.transform(['c']), axis = 1) 

# We need to do everything in tensorflow and then wrap in Lambda layer for keras so... 
from keras import backend as K 
# ***THIS DOES NOT WORK*** 
K.all(t.eval(session = sess) == l.transform(['c']), axis = 1) 

# go on from here to get smaller subset of vectors from another tensor with the indicies given by `tf.gather` 

显然上面的代码显示我曾尝试轴得到这个条件的工作,并在numpy的确实很好,但tensorflow版本是不容易从numpy移植。

有没有更好的方法来做到这一点?

回答

1

同样对你做了什么,我们可以用tf.reduce_all这是tensorflow相当于np.all

tf.reduce_all(t.eval(session = sess) == l.transform(['c']), axis = 1)