2015-10-14 76 views
0

我在模拟中有一些复杂的赋值逻辑,我想优化性能。当前的逻辑被实现为一组嵌套for循环遍历各种numpy数组。我想这个矢量化分配逻辑,但一直没能弄清楚这是否可能向量化numpy中的复杂赋值逻辑

import numpy as np 
from itertools import izip 

def reverse_enumerate(l): 
    return izip(xrange(len(l)-1, -1, -1), reversed(l)) 

materials = np.array([[1, 0, 1, 1], 
        [1, 1, 0, 0], 
        [0, 1, 1, 1], 
        [1, 0, 0, 1]]) 

vectors = np.array([[1, 1, 0, 0], 
        [0, 0, 1, 1]]) 

prices = np.array([10, 20, 30, 40]) 
demands = np.array([1, 1, 1, 1]) 

supply_by_vector = np.zeros(len(vectors)).astype(int) 

#go through each material and assign it to the first vector that the material covers 
for m_indx, material in enumerate(materials): 
    #find the first vector where the material covers the SKU 
    for v_indx, vector in enumerate(vectors): 
     if (vector <= material).all(): 
      supply_by_vector[v_indx] = supply_by_vector[v_indx] + 1 
      break 

original_supply_by_vector = np.copy(supply_by_vector) 
profit_by_vector = np.zeros(len(vectors)) 
remaining_ask_by_sku = np.copy(demands) 

#calculate profit by assigning material from vectors to SKUs to satisfy demand 
#go through vectors in reverse order (so lowest priority vectors are used up first) 
profit = 0.0 
for v_indx, vector in reverse_enumerate(vectors): 
    for sku_indx, price in enumerate(prices): 
     available = supply_by_vector[v_indx] 
     if available == 0: 
      continue 

     ask = remaining_ask_by_sku[sku_indx] 
     if ask <= 0: 
      continue 

     if vector[sku_indx]: 
      assign = ask if available > ask else available 
      remaining_ask_by_sku[sku_indx] = remaining_ask_by_sku[sku_indx] - assign 
      supply_by_vector[v_indx] = supply_by_vector[v_indx] - assign 

      profit_by_vector[v_indx] = profit_by_vector[v_indx] + assign*price 
      profit = profit + assign * price 

print 'total profit:', profit 
print 'unfulfilled demand:', remaining_ask_by_sku 
print 'original supply:', original_supply_by_vector 

结果:

total profit: 80.0 
unfulfilled demand: [0 1 0 0] 
original supply: [1 2] 
+1

滚动代码块可以阻止偶然的读者。 – hpaulj

回答

0

似乎有最里面的内迭代之间的依赖关系在嵌套循环的第二部分/组中嵌套循环,对我来说,如果不是无法进行矢量化,看起来很困难。这样,此信息基本上是试图代替进行矢量两个嵌套循环的第一组的部分解决方案,这是 -

supply_by_vector = np.zeros(len(vectors)).astype(int) 
for m_indx, material in enumerate(materials): 
    #find the first vector where the material covers the SKU 
    for v_indx, vector in enumerate(vectors): 
     if (vector <= material).all(): 
      supply_by_vector[v_indx] = supply_by_vector[v_indx] + 1 
      break 

即整个部分可以通过的向量化一行代码被替换,像​​这样 -

supply_by_vector = ((vectors[:,None] <= materials).all(2)).sum(1)