2017-06-03 101 views
0

我有一个函数,我想用numba进行编译,但是我需要计算该函数内部的阶乘。不幸的是numba不支持math.factorial在numba nopython函数中计算阶乘的最快方法

import math 
import numba as nb 

@nb.njit 
def factorial1(x): 
    return math.factorial(x) 

factorial1(10) 
# UntypedAttributeError: Failed at nopython (nopython frontend) 

我看到,它支持math.gamma(可以用来计算阶乘)表示“整数值,但是违背了真正的math.gamma功能它没有返回花车“:

@nb.njit 
def factorial2(x): 
    return math.gamma(x+1) 

factorial2(10) 
# 3628799.9999999995 <-- not exact 

math.gamma(11) 
# 3628800.0 <-- exact 

,它的缓慢相比math.factorial

%timeit factorial2(10) 
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 
%timeit math.factorial(10) 
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

所以我决定定义自己的功能:

@nb.njit 
def factorial3(x): 
    n = 1 
    for i in range(2, x+1): 
     n *= i 
    return n 

factorial3(10) 
# 3628800 

%timeit factorial3(10) 
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) 

它仍然比math.factorial慢,但是它比基于math.gamma功能numba更快的值是“精确”。

所以我正在寻找最快的方法来计算一个正整数(< = 20;为了避免溢出)factorial在nopython的numba函数中。

+2

如果你只关心整数“0..20”的阶乘因子,那么查找表可能值得检查速度。 –

+0

Arrrgggh,在我以前的评论中,我写了*你的*我应该写*你是*。或*如果您唯一的担心是...... * –

+0

您可以尝试重新实现numba中的python方法 - 它会通过一些额外的步骤来以特定方式对乘法进行排序 - https://github.com/python/ cpython/blob/3.6/Modules/mathmodule.c#L1275 – chrisb

回答

1

对于值< = 20,python正在使用查找表,正如评论中所建议的那样。 https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452

LOOKUP_TABLE = np.array([ 
    1, 1, 2, 6, 24, 120, 720, 5040, 40320, 
    362880, 3628800, 39916800, 479001600, 
    6227020800, 87178291200, 1307674368000, 
    20922789888000, 355687428096000, 6402373705728000, 
    121645100408832000, 2432902008176640000], dtype='int64') 

@nb.jit 
def fast_factorial(n): 
    if n > 20: 
     raise ValueError 
    return LOOKUP_TABLE[n] 

从Python中叫它比Python版本稍慢由于numba调度开销。

In [58]: %timeit math.factorial(10) 
10000000 loops, best of 3: 79.4 ns per loop 

In [59]: %timeit fast_factorial(10) 
10000000 loops, best of 3: 173 ns per loop 

但是在另一个numba函数中调用可以更快。

def loop_python(): 
    for i in range(10000): 
     for n in range(21): 
      math.factorial(n) 

@nb.njit 
def loop_numba(): 
    for i in range(10000): 
     for n in range(21): 
      fast_factorial(n) 

In [65]: %timeit loop_python() 
10 loops, best of 3: 36.7 ms per loop 

In [66]: %timeit loop_numba() 
10000000 loops, best of 3: 73.6 ns per loop 
+0

Numba做了积极的循环优化,所以如果你不保存'fast_factorial'的结果,它甚至不会执行循环。 – MSeifert