2017-10-18 89 views
3

我有这段在应用程序运行期间多次调用的代码。 它需要一个表示值的数组(数组array_value)。 这些应该在Zone_array中定义的区域中总结出来。 zone_ids表示zone_array中所有可能区域的列表。如何优化从另一个数组索引值的数组中求和值的numpy循环,其中值等于循环索引

它基本上是这样的:我得到了一张人口栅格地图,我想知道有多少人生活在区域地图的每个区域。

代码:

values = np.zeros(len(zone_ids)) 
for i in zone_ids: 
    values[i] = round(np.nansum(value_array[zone_array == i]), 2) 
return values 

的罪魁祸首似乎是for循环,但我还没有找到一个方法来消除它,并有相同的结果。

我尝试了与计数,但我没有成功。 使用numba jit也没有效果。

我想远离cython,因为此代码将用于没有cython支持的Qgis插件。

测试代码:

import numpy as np 


def fill_values(zone_array, value_array, zone_ids): 
    values = np.zeros(len(zone_ids)) 
    for i in zone_ids: 
     values[i] = round(np.nansum(value_array[zone_array == i]), 2) 
    return values 


def run(): 
    # 300 different zones 
    zone_ids = range(300) 
    # zone map with 300 zones 
    zone_array = (np.random.rand(2000, 2000) * 300).astype(int) 
    # value map from which we want the sum of values per zone (real map can have NaN values) 
    value_array = (np.random.rand(2000, 2000) * 10.) 
    value_array[5, 5] = np.NAN 
    fill_values(zone_array, value_array, zone_ids) 


if __name__ == '__main__': 
    run() 

1.92小号±每个环路17.5毫秒(平均值±标准偏差7点运行时,1个循环的每一个。)

随着bincount的执行由Divakar的建议:

203毫秒±15.2毫秒每环(平均±标准。开发7点运行,1环的每一个)

+0

的罪魁祸首不是for循环。相反,问题在于比较'zone_array == i'。对于每个zone_id'i',必须检查所有2000x2000 = 4e6的值是否等于“i”。 – Chickenmarkus

+0

如果我减少区域ID的数量我得到一个速度增加,所以for循环仍然涉及到性能问题。因为我没有别的选择,我知道没有做'zone_array ==我'我专注于循环。最好的是,我可以以某种方式使用'zone_array == zone_ids'并跳过循环。 –

+0

您可以使用'zone_array [:,:,] == zone_ids'广播比较,但仍然会在for循环中留下索引,并且不会提高性能。 – user2699

回答

1

随着bincount直接使用,你就必须NaNs在求和中。因此,您可以简单地将NaNs替换为zeros并使用bincount。这应该更快,是一个矢量化的解决方案。

因此,实现起来 -

val_nonan = np.where(np.isnan(value_array), 0, value_array) 
out = np.round(np.bincount(zone_array.ravel(), val_nonan.ravel()),2) 
+0

这适用于我的问题。非常感谢。我想我的帐号尝试在哪里被nan值弄乱。此外'values = out [zone_ids]'用于您想要区域子集结果的情况。 –

相关问题