你想要的就像是一个自定义的减少功能。如果要保持像最大值的指数indices
话,我会建议使用tf.reduce_max
:
max_params = tf.reduce_max(params_tensor, reduction_indices=[2])
否则,这里是让你想要什么(张量对象是不可转让的,所以我们创建的二维表的一种方式张量和它使用tf.pack
)包装:
import tensorflow as tf
import numpy as np
with tf.Graph().as_default():
params_tensor = tf.pack(np.random.randint(1,256, [5,5,10]).astype(np.int32))
indices = tf.pack(np.random.randint(1,10,[5,5]).astype(np.int32))
output = [ [None for j in range(params_tensor.get_shape()[1])] for i in range(params_tensor.get_shape()[0])]
for i in range(params_tensor.get_shape()[0]):
for j in range(params_tensor.get_shape()[1]):
output[i][j] = params_tensor[i,j,indices[i,j]]
output = tf.pack(output)
with tf.Session() as sess:
params_tensor,indices,output = sess.run([params_tensor,indices,output])
print params_tensor
print indices
print output
来源
2017-03-05 00:54:40
Ali
谢谢。这就是我想要的 –
对于我来说,将参数平滑为(-1,256)和索引为(-1)似乎更容易些,使用gather_nd然后解开平铺,而不是平铺 – Maxim
geez。这真的很聪明,也@ @ @ $。烟雾从我头上读出来。它也很有用!谢谢。为我节省了很多时间。 – thang