2017-04-19 83 views
1

我正在优化我的代码的瓶颈部分 - 迭代函数a'= f(a),其中a和a'是N乘1的向量,直到max(abs(a' - a))足够小。可以在numpy.max(numpy.abs(a-b))上使用Cython/Numba编译函数吗?

我已经把对F(A)一Numba包装,并得到了最优化的纯NumPy的版本,我能够produe一个不错的加速(切运行约50%)。

我试着写numpy.max的C-兼容版本(numpy.abs(aprime - 一)),但事实证明这是慢!我实际上失去了我从Numba获得的所有收益 - 追赶迭代的第一部分!

有可能是一种方式,Numba或用Cython是改进numpy.max(numpy.abs(aprime - 一))?我重现我的代码以供参考,其中a是P0和'是Pprime:

编辑:对我来说,它似乎是“扁平()”输入到“maxabs()”的重要。当我这样做时,表现并不比NumPy差。然后,当我按照JoshAdel的建议在定时括号外做一个“干运行”时,带有“maxabs”的循环比带有numpy.max(numpy.abs())的循环略好一些。

from numba import jit 
import numpy as np 

### Preliminaries, to make the working example fully functional 

n = 1200 
Gammer = np.exp(-np.random.rand(n,n)) 

alpher = np.ones((n,1)) 
xxer = 10000*np.random.rand(n,1) 

chii = 6.5 
varkappa = 6.5 
phi3 = 1.5 
A = .5 
sig = .2 

mmer = np.dot(Gammer,xxer**phi3) 


totalprod = A*alpher + (1-A)*mmer 
Gammerchii = Gammer**chii 
Gammerrats = Gammerchii[:,0].flatten()/Gammerchii[0,:].flatten() 
Gammerrats[(Gammerchii[0,:].flatten() == 0) | (Gammerchii[:,0].flatten() == 0)] = 1. 
P0 = (Gammerrats*(xxer[0]/totalprod[0])*(totalprod/xxer).flatten())**(1/(1+2*chii)) 
P0 *= n/np.sum(P0) 
### End of preliminaries 

### This is the function to produce a' = f(a) 
@jit 
def Piteration(P0, chii, sig, n, xxer, totalprod, Gammerrats, Gammerchii): 
    Mac = np.zeros((n,)) 
    Pprime = np.zeros((n,)) 
    themacpow = 1-(1/chii)*(sig/(1-sig)) 
    specialchiipow = 1/(1+2*chii) 
    Psum = 0. 

    for i in range(n): 
     for j in range(n): 
      Mac[j] += ((P0[i]/P0[j])**chii)*Gammerchii[i,j]*totalprod[j] 

    for i in range(n): 
     Pprime[i] = (Gammerrats[i]*(xxer[0]/totalprod[0])*(totalprod[i]/xxer[i])*((Mac[i]/Mac[0])**themacpow))**specialchiipow 
     Psum += Pprime[i] 

    Psum = n/Psum 

    for i in range(n): 
     Pprime[i] *= Psum 

    return Pprime 

### This is the function to find max(abs(aprime - a)) 
@jit 
def maxabs(vec1,vec2,n): 
    themax = 0. 
    curdiff = 0. 
    for i in range(n): 
     curdiff = vec1[i] - vec2[i] 
     if curdiff < 0: 
      curdiff *= -1 
     if curdiff > themax: 
      themax = curdiff 
    return themax 

### This is the main loop 
diff = 1000. 
while diff > 1e-2: 
    Pprime = Piteration(P0.flatten(), chii, sig, n, xxer.flatten(), totalprod.flatten(), Gammerrats.flatten(), Gammerchii) 

    diff = maxabs(P0.flatten(),Pprime.flatten(),n) 
    P0 = 1.*Pprime 
+0

_“我试着写一个C兼容的版本[...]但事实证明这是慢” _ - 你能告诉我们这个实现? – Eric

+0

执行如上。 “C兼容”我的意思是它使用循环代替矢量化。 –

回答

0

当我计时您maxabs功能VS np.max(np.abs(vec1 - vec2))的形状(1200,)的阵列中,numba版本〜使用numba 0.32.0快2.6倍。

当时间代码,请务必在运行之前你的函数时一次,这样你就不会包括它采取JIT的代码,你只需支付第一次的时间。通常使用timeit并且多次运行照顾此。我不知道你是如何做的时间,虽然我看到使用maxabs与numpy的呼叫几乎没有什么区别,但大部分运行时似乎在呼叫Piteration

+0

我正在使用“time.time()”计时代码 - 记录每个while循环之前的时间并记录下后的时间,并采取差异。 –

+0

我发现的一件事对我有所帮助 - 如果我“.flatten()”maxabs的输入,它运行得更快。 –

+0

而且 - 我试过在定时括号外做maxabs()的“空运行”,这似乎“修复”了“问题”。通过“空转”来首先使用功能,Numba的记录时间稍好一些。但是,正如你发现的那样,收益并不大 - 低于10%。 –

相关问题