您可以使用tf.gather_nd()
有这样的代码:
import tensorflow as tf
# B = 3
# N = 4
# M = 2
# [B x N x 3]
data = tf.constant([
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
[[100, 101, 102], [103, 104, 105], [106, 107, 108], [109, 110, 111]],
[[200, 201, 202], [203, 204, 205], [206, 207, 208], [209, 210, 211]],
])
# [B x M]
indices = tf.constant([
[0, 2],
[1, 3],
[3, 2],
])
indices_shape = tf.shape(indices)
indices_help = tf.tile(tf.reshape(tf.range(indices_shape[0]), [indices_shape[0], 1]) ,[1, indices_shape[1]]);
indices_ext = tf.concat([tf.expand_dims(indices_help, 2), tf.expand_dims(indices, 2)], axis = 2)
new_data = tf.gather_nd(data, indices_ext)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('data')
print(sess.run(data))
print('\nindices')
print(sess.run(indices))
print('\nnew_data')
print(sess.run(new_data))
new_data
将是:
[[[ 0 1 2]
[ 6 7 8]]
[[103 104 105]
[109 110 111]]
[[209 210 211]
[206 207 208]]]
它的工作原理!但是你能解释一下这三行代码的逻辑是什么吗('indices_help','indices_ext','new_data')? –
'tf.gather_nd'需要切片索引(两个坐标为''),但我们只有一个坐标'',在索引数组中有隐式坐标'b'。因此,我们应该通过添加_batch_的索引来扩展'indices'的每个元素。就是这个。首先我们制作索引为“批次”的“索引_帮助”,然后将其与源'索引'连接起来 –