2013-03-18 107 views
4

如果我有一个这样的ndarray:numpy的3D到2D变换基于二维掩模阵列

>>> a = np.arange(27).reshape(3,3,3) 
>>> a 
array([[[ 0, 1, 2], 
     [ 3, 4, 5], 
     [ 6, 7, 8]], 

     [[ 9, 10, 11], 
     [12, 13, 14], 
     [15, 16, 17]], 

     [[18, 19, 20], 
     [21, 22, 23], 
     [24, 25, 26]]]) 

我知道可以使用np.max(axis=...)得到沿某一轴的最大:

>>> a.max(axis=2) 
array([[ 2, 5, 8], 
     [11, 14, 17], 
     [20, 23, 26]]) 

替代地,我可以得到沿该轴的指数,其对应于来自以下的最大值:

>>> indices = a.argmax(axis=2) 
>>> indices 
array([[2, 2, 2], 
     [2, 2, 2], 
     [2, 2, 2]]) 

我的问题 - 给定数组indices和数组a,有没有一个优雅的方法来重现由a.max(axis=2)返回数组的数组?

这可能会工作:

import itertools as it 
import numpy as np 
def apply_mask(field,indices): 
    data = np.empty(indices.shape) 

    #It seems highly likely that there is a more numpy-approved way to do this. 
    idx = [range(i) for i in indices.shape] 
    for idx_tup,zidx in zip(it.product(*idx),indices.flat): 
     data[idx_tup] = field[idx_tup+(zidx,)] 
    return data 

但是,它似乎很哈克/低效。它也不允许我在除“最后”轴以外的任何轴上使用它。是否有一个numpy函数(或使用神奇的numpy索引)来完成这项工作?天真的a[:,:,a.argmax(axis=2)]不起作用。

UPDATE

看来下面也可以(并且是一个更好一点):

import numpy as np 
def apply_mask(field,indices): 
    data = np.empty(indices.shape) 

    for idx_tup,zidx in np.ndenumerate(indices): 
     data[idx_tup] = field[idx_tup+(zidx,)] 

    return data 

我想这样做,因为我想提取指数基于1个数组中的数据(通常使用argmax(axis=...)),并使用这些索引将数据从一堆其他(等同形状的)数组中取出。我愿意选择其他方法来完成这个任务(例如,使用布尔型屏蔽阵列)。但是,我喜欢使用这些“索引”数组的“安全”。有了这个,我保证有适当数量的元素来创建一个新的数组,这个数组看起来像是3D场中的2D切片。

+0

淘numpy的/ SciPy的几分钟后,我同意你的自我答案( S)。最终,走“黑客”路线更具可读性和实用性。尽管如此,你可能对函数'ravel','ravel_multi_index','unravel_index','flat'和'flatten'感兴趣。 – 2013-03-18 04:19:18

回答

2

下面是一些神奇的numpy索引,它可以做你想做的,但不幸的是它很难读。

def apply_mask(a, indices, axis): 
    magic_index = [np.arange(i) for i in indices.shape] 
    magic_index = np.ix_(*magic_index) 
    magic_index = magic_index[:axis] + (indices,) + magic_index[axis:] 
    return a[magic_index] 

或同样不可读:

def apply_mask(a, indices, axis): 
    magic_index = np.ogrid[tuple(slice(i) for i in indices.shape)] 
    magic_index.insert(axis, indices) 
    return a[magic_index] 
+1

我其实喜欢'ogrid'路线比'np.ix_'好一点。无论出于何种原因,'np.ix_'对我的口味总是显得有点太神奇。 – mgilson 2013-03-18 12:09:43

-1

我用index_at()创建完整的索引:

import numpy as np 

def index_at(idx, shape, axis=-1): 
    if axis<0: 
     axis += len(shape) 
    shape = shape[:axis] + shape[axis+1:] 
    index = list(np.ix_(*[np.arange(n) for n in shape])) 
    index.insert(axis, idx) 
    return tuple(index) 

a = np.random.randint(0, 10, (3, 4, 5)) 

axis = 1 
idx = np.argmax(a, axis=axis) 
print a[index_at(idx, a.shape, axis=axis)] 
print np.max(a, axis=axis)