2015-06-21 109 views
1

我有2D narrays的集合,这取决于两个整索引,说p1和p2,具有相同形状的每个矩阵。np.argmax上多维数组,保持一些索引固定

然后我需要找到,对于每对(P1,P2),该基质的最大值和这些最大值的索引。 一个不重要,虽然速度慢,办法做到这一点会是做这样的事情

import numpy as np 
import itertools 
range1=range(1,10) 
range2=range(1,20) 

for p1,p2 in itertools.product(range1,range1): 
    mat=np.random.rand(10,10) 
    index=np.unravel_index(mat.argmax(), mat.shape) 
    m=mat[index] 
    print m, index 

对于我的应用程序,这是不幸的是太慢了,我想由于双for循环使用。 因此我试图将所有东西都打包在一个4维数组(比如BigMatrix)中,其中前两个坐标是索引p1,p2,其他2个是矩阵的坐标。

>>res=np.amax(BigMatrix,axis=(2,3)) 
    >>res.shape 
     (10,20) 
    >>res[p1,p2]==np.amax(BigMatrix[p1,p2,:,:]) 
     True 

作品如预期的,因为它遍历2和3轴的np.amax命令。我该如何为np.argmax做同样的事情?请记住,速度很重要。

非常感谢你提前,

恩佐

回答

1

在这里,这对我的作品,其中Mat是大矩阵。

# flatten the 3 and 4 dimensions of Mat and obtain the 1d index for the maximum 
# for each p1 and p2 
index1d = np.argmax(Mat.reshape(Mat.shape[0],Mat.shape[1],-1),axis=2) 

# compute the indices of the 3 and 4 dimensionality for all p1 and p2 
index_x, index_y = np.unravel_index(index1d,Mat[0,0].shape) 

# bring the indices into the right shape 
index = np.array((index_x,index_y)).reshape(2,-1).transpose() 

# get the maxima 
max_val = np.amax(Mat,axis=(2,3)).reshape(-1) 

# combine maxima and indices 
sol = np.column_stack((max_val,index)) 

print sol 
+0

非常感谢你,这是快速和正确。我应该更深入地介绍numpy中的索引问题。 –