2017-10-21 128 views
1

我想收集指定轴中指定索引的元素,如下所示。如何收集numpy中特定索引的元素?

x = [[1,2,3], [4,5,6]] 
index = [[2,1], [0, 1]] 
x[:, index] = [[3, 2], [4, 5]] 

这实质上是在pytorch中的收集操作,但正如你所知道的,这在numpy中是不可行的。我想知道在numpy中是否有这样的“聚集”操作?

回答

1
>>> x = np.array([[1,2,3], [4,5,6]]) 
>>> index = np.array([[2,1], [0, 1]]) 
>>> x_axis_index=np.tile(np.arange(len(x)), (index.shape[1],1)).transpose() 
>>> print x_axis_index 
[[0 0] 
[1 1]] 
>>> print x[x_axis_index,index] 
[[3 2] 
[4 5]] 
+0

注意还可以使用'np.arange(len(x))'不确定是否np.range是可取的! –

+0

注意:range(x.shape [0])和range(len(x))给出了一个列表,而np.arange(len(x))和np.arange(x.shape [0])给出了一个数组。数组和列表都有相同的元素。 – Sam17

+0

我想我的问题/陈述更多的是关于性能的问题,在一个非常大的数组中,我怀疑np.range的索引会更快(len vs shape肯定无关紧要)。 –