2015-07-12 123 views
1

我写了一个函数,它接受一组随机笛卡尔坐标并返回保留在某个空间域内的子集。为了说明:使用数组索引在3D数组上应用2D数组函数

grid = np.ones((5,5)) 
grid = np.lib.pad(grid, ((10,10), (10,10)), 'constant') 

>> np.shape(grid) 
(25, 25) 

random_pts = np.random.random(size=(100, 2)) * len(grid) 

def inside(input): 
    idx = np.floor(input).astype(np.int) 
    mask = grid[idx[:,0], idx[:,1]] == 1 
    return input[mask] 

>> inside(random_pts) 
array([[ 10.59441506, 11.37998288], 
     [ 10.39124766, 13.27615815], 
     [ 12.28225713, 10.6970708 ], 
     [ 13.78351949, 12.9933591 ]]) 

但现在我想同时产生n套random_pts,并保持满足同样功能的条件n相应的子集的能力。所以,如果n=3

random_pts = np.random.random(size=(3, 100, 2)) * len(grid) 

没有求助于for循环,我怎么可能指数我的变量,使得inside(random_pts)回报像

array([[[ 17.73323523, 9.81956681], 
     [ 10.97074592, 2.19671642], 
     [ 21.12081044, 12.80412997]], 

     [[ 11.41995519, 2.60974757]], 

     [[ 9.89827156, 9.74580059], 
     [ 17.35840479, 7.76972241]]]) 
+0

只是好奇 - 做了发布的解决方案为您工作? – Divakar

+0

这是功能,但不是很实用;当函数需要迭代调用时,额外的数组操作会导致性能下降。我希望有一个更直接的方法来切分输入,可以达到相同的结果。 –

+1

因此,对于问题中的发布数据,您希望有三个独立的数组,对吗?为了获得这样的单独数组,'np.split'是最为人所知的方法,但速度很慢,因为存在一个分裂多数组操作。我认为这是这里最慢的部分。如果您对非分割输出没问题,则发布的解决方案中的“out_cat_array”可能是您的输出。 – Divakar

回答

1

一种方法 -

def inside3d(input): 
    # Get idx in 3D 
    idx3d = np.floor(input).astype(np.int) 

    # Create a similar mask as witrh 2D case, but in 3D now 
    mask3d = grid[idx3d[:,:,0], idx3d[:,:,1]]==1 

    # Count of mask matches for each index in 0th dim  
    counts = np.sum(mask3d,axis=1) 

    # Index into input to get masked matches across all elements in 0th dim 
    out_cat_array = input.reshape(-1,2)[mask3d.ravel()] 

    # Split the rows based on the counts, as the final output 
    return np.split(out_cat_array,counts.cumsum()[:-1]) 

验证结果 -

创建3D随机输入:

In [91]: random_pts3d = np.random.random(size=(3, 100, 2)) * len(grid) 

随着inside3d:

In [92]: inside3d(random_pts3d) 
Out[92]: 
[array([[ 10.71196268, 12.9875877 ], 
     [ 10.29700184, 10.00506662], 
     [ 13.80111411, 14.80514828], 
     [ 12.55070282, 14.63155383]]), array([[ 10.42636137, 12.45736944], 
     [ 11.26682474, 13.01632751], 
     [ 13.23550598, 10.99431284], 
     [ 14.86871413, 14.19079225], 
     [ 10.61103434, 14.95970597]]), array([[ 13.67395756, 10.17229061], 
     [ 10.01518846, 14.95480515], 
     [ 12.18167251, 12.62880968], 
     [ 11.27861513, 14.45609646], 
     [ 10.895685 , 13.35214678], 
     [ 13.42690335, 13.67224414]])] 

随着内:

In [93]: inside(random_pts3d[0]) 
Out[93]: 
array([[ 10.71196268, 12.9875877 ], 
     [ 10.29700184, 10.00506662], 
     [ 13.80111411, 14.80514828], 
     [ 12.55070282, 14.63155383]]) 

In [94]: inside(random_pts3d[1]) 
Out[94]: 
array([[ 10.42636137, 12.45736944], 
     [ 11.26682474, 13.01632751], 
     [ 13.23550598, 10.99431284], 
     [ 14.86871413, 14.19079225], 
     [ 10.61103434, 14.95970597]]) 

In [95]: inside(random_pts3d[2]) 
Out[95]: 
array([[ 13.67395756, 10.17229061], 
     [ 10.01518846, 14.95480515], 
     [ 12.18167251, 12.62880968], 
     [ 11.27861513, 14.45609646], 
     [ 10.895685 , 13.35214678], 
     [ 13.42690335, 13.67224414]])