2017-05-30 122 views
0

比方说,我有一个尺寸为[batch_size, 5, 10]的张量,称为my_tensor。 我还有一个尺寸为[batch_size, 1]的另一个张量,其中包含一个名为selecter的索引。如何过滤基于带索引张量的张量流张量?

我想对于过滤my_tensorselecter生产规模[batch_size, 10]新张量,即只选择珍视selecter包含。基本上,它有点减少中间维度(其大小为5)。我觉得tf.where是正确的选择,但不确定。 我真的很感谢你的帮助!

回答

1

解决方法是使用tf.gather_nd

tf.gather_nd(
    my_tensor, 
    tf.stack([tf.range(batch_size), tf.squeeze(selecter)], axis=-1)) 

如果你构建selecter是从一开始1-d可以摆脱squeeze的。

+0

这是完美的。非常感谢你! –

+0

你用什么版本的tensorflow?我有1.3.0和我的tf.gather_nd不接受轴参数。但是,有这个tf.gather。 – omikron

0

替代的解决方案,工作在Tensorflow 1.3:

max_selecter = tf.reduce_max(selecter) + 1 
my_tensor = tf.boolean_mask(
    outputs, 
    tf.logical_xor(
     tf.sequence_mask(my_tensor + 1, max_selecter), 
     tf.sequence_mask(my_tensor, max_selecter) 
    ) 
)