2017-07-25 169 views
5

问题

我有两个numpy的阵列,Aindicesnumpy的匹配索引尺寸

A具有尺寸m x n x 10000. indices具有尺寸m x n x 5(从argpartition(A, 5)[:,:,:5]输出)。 我想得到一个m×n×5的数组,其中包含对应于indicesA的元素。

尝试

indices = np.array([[[5,4,3,2,1],[1,1,1,1,1],[1,1,1,1,1]], 
    [500,400,300,200,100],[100,100,100,100,100],[100,100,100,100,100]]) 
A = np.reshape(range(2 * 3 * 10000), (2,3,10000)) 

A[...,indices] # gives an array of size (2,3,2,3,5). I want a subset of these values 
np.take(A, indices) # shape is right, but it flattens the array first 
np.choose(indices, A) # fails because of shape mismatch. 

动机

我试图得到A[i,j] 5个最大值为每i<mj<n使用np.argpartition因为阵列可以得到相当大的排序顺序。

回答

5

您可以使用advanced-indexing -

m,n = A.shape[:2] 
out = A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices] 

采样运行 -

In [330]: A 
Out[330]: 
array([[[38, 21, 61, 74, 35, 29, 44, 46, 43, 38], 
     [22, 44, 89, 48, 97, 75, 50, 16, 28, 78], 
     [72, 90, 48, 88, 64, 30, 62, 89, 46, 20]], 

     [[81, 57, 18, 71, 43, 40, 57, 14, 89, 15], 
     [93, 47, 17, 24, 22, 87, 34, 29, 66, 20], 
     [95, 27, 76, 85, 52, 89, 69, 92, 14, 13]]]) 

In [331]: indices 
Out[331]: 
array([[[7, 8, 1], 
     [7, 4, 7], 
     [4, 8, 4]], 

     [[0, 7, 4], 
     [5, 3, 1], 
     [1, 4, 0]]]) 

In [332]: m,n = A.shape[:2] 

In [333]: A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices] 
Out[333]: 
array([[[46, 43, 21], 
     [16, 97, 16], 
     [64, 46, 64]], 

     [[81, 14, 43], 
     [87, 24, 47], 
     [27, 52, 95]]]) 

为了得到相对应的最大沿最后轴5种元素的索引,我们将使用argpartition,像这样 -

indices = np.argpartition(-A,5,axis=-1)[...,:5] 

为了保持订单从最高到最低,我们e range(5)而不是5

1

为子孙后代,下面采用Divakar的答案来完成原来的目标,即在排序的顺序返回的前5名值的所有i<m, j<n

m, n = np.shape(A)[:2] 

# get the largest 5 indices for all m, n 
top_unsorted_indices = np.argpartition(A, -5, axis=2)[...,-5:] 

# get the values corresponding to top_unsorted_indices 
top_values = A[np.arange(m)[:,None,None], np.arange(n)[:,None], top_unsorted_indices] 

# sort the top 5 values 
top_sorted_indices = top_unsorted_indices[np.arange(m)[:,None,None], np.arange(n)[:,None], np.argsort(-top_values)]