2016-12-09 58 views
2

由于性能的原因,除了NumPy之外,我已经开始使用Numba了。我的Numba算法正在工作,但我有一种感觉,它应该更快。有一点是减缓它。以下是代码片段:在numba中的性能嵌套循环

@nb.njit 
def rfunc1(ws, a, l): 
    gn = a**l 
    for x1 in range(gn): 
     for x2 in range(gn): 
      for x3 in range(gn): 
       y = 0.0 
       for i in range(1, l): 
        if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
        numpy.all(ws[x1][i:l] == ws[x3][i:l]): 
         y += 1 
        if numpy.all(ws[x1][0:i] == ws[x2][0:i]) and 
        numpy.all(ws[x1][i:l] == ws[x3][i:l]): 
         y += 1 

在我看来,if命令减缓下来。有没有更好的办法? (我试图在这里实现是与先前发布的问题是什么:Count possibilites for single crossoversws是尺寸含0(gn, l)的NumPy的阵列的和1

+0

你意识到这种规模可怕地与'gn'的大小...? –

+0

是的,l的最大大小是9,a总是2 – HighwayJohn

+0

你在Python 2还是3? –

回答

2

鉴于希望确保所有项目都是平等的逻辑,你可以利用这样一个事实,即如果有任何不相等的事实,则可以将计算短路(即停止比较)。我稍微修改原来的功能,以便(1)你不要重复相同的比较两次,和(2)值Y在所有的嵌套循环,从而有可能进行比较的回报:

@nb.njit 
def rfunc1(ws, a, l): 
    gn = a**l 
    ysum = 0 
    for x1 in range(gn): 
     for x2 in range(gn): 
      for x3 in range(gn): 
       y = 0.0 
       for i in range(1, l): 
        if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]): 
         y += 1 
         ysum += 1 

    return ysum 


@nb.njit 
def rfunc2(ws, a, l): 
    gn = a**l 
    ysum = 0 
    for x1 in range(gn): 
     for x2 in range(gn): 
      for x3 in range(gn): 
       y = 0.0 
       for i in range(1, l): 

        incr_y = True 
        for j in range(i): 
         if ws[x1,j] != ws[x2,j]: 
          incr_y = False 
          break 

        if incr_y is True: 
         for j in range(i,l): 
          if ws[x1,j] != ws[x3,j]: 
           incr_y = False 
           break 
        if incr_y is True: 
         y += 1 
         ysum += 1 
    return ysum 

我不知道完整的功能是什么样子,但希望这可以帮助你开始正确的道路。

现在对于一些计时:

l = 7 
a = 2 
gn = a**l 
ws = np.random.randint(0,2,size=(gn,l)) 
In [23]: 

%timeit rfunc1(ws, a , l) 
1 loop, best of 3: 2.11 s per loop 


%timeit rfunc2(ws, a , l) 
1 loop, best of 3: 39.9 ms per loop 

In [27]: rfunc1(ws, a , l) 
Out[27]: 131919 

In [30]: rfunc2(ws, a , l) 
Out[30]: 131919 

这就给了你50倍的加速。

+0

如何在'nopython = True'中使用'jit'? – pbreach

+0

'njit'相当于'jit(nopython = True)' – JoshAdel

+0

非常感谢! :) – HighwayJohn

2

而不只是“有感觉”在您的瓶颈,为什么不轮廓你的代码,并找到究竟在哪里呢?

性能分析的第一个目标是测试一个代表性的系统,以确定什么是缓慢的(或使用太多RAM,或导致太多的磁盘I/O或网络I/O)。

性能分析通常会增加开销(10倍到100倍的减速可能是典型的),您仍然希望尽可能使用类似于实际情况的代码。提取测试用例并隔离您需要测试的系统部分。最好是已经写入它自己的一组模块中。基本技术包括IPython中的%timeit魔术,time.time(),timing decorator(请参见下面的示例)。您可以使用这些技术来了解语句和函数的行为。

然后你有cProfile这将给你一个问题的高层次的看法,所以你可以引导你的注意力到关键的功能。

接下来,看看line_profiler,这将逐行分析您选择的功能。结果将包括每行被调用的次数以及每行所用时间的百分比。这正是您了解缓慢运行以及为什么需要的信息。

perf stat帮助您理解最终在CPU上执行的指令的数量以及CPU的高速缓存利用率。这允许对矩阵操作进行高级调整。

heapy可以跟踪Python内存中的所有对象。这非常适合寻找奇怪的内存泄漏。如果您使用的是长时间运行的系统,那么 然后dowser会引起您的兴趣:它允许您通过Web浏览器界面在长时间运行的过程中反思活动对象。

为了帮助您理解RAM使用率高的原因,请查看memory_profiler.这对于跟踪随时间推移的RAM使用情况特别有用,因为您可以向同事(或您自己)解释为什么某些功能使用的RAM多于预期。

例:定义一个装饰来自动定时测量

from functools import wraps 

def timefn(fn): 
    @wraps(fn) 
    def measure_time(*args, **kwargs): 
     t1 = time.time() 
     result = fn(*args, **kwargs) 
     t2 = time.time() 
     print ("@timefn:" + fn.func_name + " took " + str(t2 - t1) + " seconds") 
     return result 
    return measure_time 

@timefn 
def your_func(var1, var2): 
    ... 

有关详细信息,我建议读High performance Python(米莎戈雷利克;伊恩Ozsvald),从该上述被采购。

+0

这些都是很好的**一般**建议,但没有真正适用于这个问题。例如,你不能在numba函数内使用'line_profiler',也不能在'nopython'模式下调用'time.time'。最初的问题是关于改进一个函数的性能(大概已经确定为热点),这个函数是在numba中编码的。通常在那里,你必须对Numba可以转换成高性能的llvm代码有一个直觉,许多通用技术都无法解决这个问题。 – JoshAdel

+0

@JoshAdel:我想向OP建议,不要猜测瓶颈在哪里,但可以通过配置文件来确定。为了未来读者的利益,我试图使分析选项稍微完整(即使并非所有的都适用于OP的情况)。 – boardrider