2010-08-19 42 views
10

几年前,有人posted活动状态食谱作比较,三个python/NumPy函数;其中每个接受相同的参数并返回相同的结果,距离矩阵为什么在这里循环跳动索引?

其中两个是从公开资料中获得的;他们都是 - 或者他们似乎是我惯用的numpy代码。创建距离矩阵所需的重复计算由numpy的优雅索引语法驱动。这里是其中的一个:

from numpy.matlib import repmat, repeat 

def calcDistanceMatrixFastEuclidean(points): 
    numPoints = len(points) 
    distMat = sqrt(sum((repmat(points, numPoints, 1) - 
      repeat(points, numPoints, axis=0))**2, axis=1)) 
    return distMat.reshape((numPoints,numPoints)) 

使用一个循环(这显然是一个很大的循环考虑到的只是1000的2D点的距离矩阵,有一万个条目)第三创建距离矩阵。乍一看,这个函数在我看来像我在学习NumPy时编写的代码,我会先编写NumPy代码,然后逐行翻译它。

活动状态帖子发布几个月后,比较三者的性能测试结果在NumPy邮件列表上发布并在thread中讨论。

与事实上的循环功能显著跑赢另外两个:在线程

from numpy import mat, zeros, newaxis 

def calcDistanceMatrixFastEuclidean2(nDimPoints): 
    nDimPoints = array(nDimPoints) 
    n,m = nDimPoints.shape 
    delta = zeros((n,n),'d') 
    for d in xrange(m): 
    data = nDimPoints[:,d] 
    delta += (data - data[:,newaxis])**2 
    return sqrt(delta) 

一位与会者(凯尔·Mierle)提供一个理由,这可能是真实的:

我怀疑这会更快的原因是 它具有更好的地方性,完全完成一个相对较小的工作集上的计算,然后再转到下一个工作集之一。一行 必须重复将可能较大的MxN阵列拉入处理器。

通过这张海报自己的帐户,他的评论只是一个怀疑,似乎并没有进一步讨论。

有关如何解释这些结果的其他想法?

特别是,有没有一个有用的规则 - 关于什么时候循环和何时索引 - 可以从这个例子中提取作为编写numpy代码的指导?

对于那些不熟悉NumPy的人,或者没有看过代码的人,这种比较不是基于边缘案例 - 如果是的话,这对我来说肯定不会那么有趣。相反,这种比较涉及在矩阵计算中执行共同任务的功能(即,创建给定两个前件的结果数组)。而且,每个函数都是由最常见的numpy内建插件组成的。

回答

11

TL; DR上面的第二个代码仅在点的维数上循环(对于3D点,通过for循环的次数为3次),因此循环并不多。上面第二个代码中真正的加速是,它更好地利用了Numpy的力量,以避免在找到点之间的差异时创建一些额外的矩阵。这减少了使用的内存和计算量。

更长解释 我认为calcDistanceMatrixFastEuclidean2函数正在欺骗你的循环或许。它仅循环点的维数。对于1D点,循环只执行一次,对于2D,两次,对于3D,则执行三次。这实际上并没有太多循环。

让我们分析一下代码,看看为什么这个代码比另一个更快。 calcDistanceMatrixFastEuclidean我会打电话fast1calcDistanceMatrixFastEuclidean2fast2

fast1是基于Matlab的做事方式,如repmap函数所证明的那样。在这种情况下,repmap函数会创建一个数组,它只是原来的数据一遍又一遍地重复。但是,如果您查看该函数的代码,则效率非常低。它使用许多Numpy功能(3 reshape s和2 repeat s)来执行此操作。 repeat函数也用于创建一个包含原始数据的数组,每个数据项重复多次。如果我们的输入数据是[1,2,3],那么我们从[1,1,1,2,2,2,3,3,3]减去[1,2,3,1,2,3,1,2,3]。 Numpy必须在运行Numpy的C代码之间创建大量额外的矩阵,而这些代码本可以避免。

fast2使用更多的Numpy的繁重工作,而不会在Numpy调用之间创建尽可能多的矩阵。 fast2循环通过点的每个维度,进行减法并保持每个维度之间的平方差的总计。只有最后才是平方根。到目前为止,这可能听起来不如fast1那样有效,但fast2通过使用Numpy的索引避免了做repmat的东西。为简单起见,我们来看一维情况。 fast2制作数据的一维数组,并从数据的2D(N×1)数组中减去它。这将创建每个点与所有其他点之间的差异矩阵,而不必使用repmatrepeat,从而绕过创建大量额外数组。这是真正的速度差异在我看来。 fast1在矩阵之间创建了许多额外的内容(并且它们的计算开销很大)以找到点之间的差异,而fast2更好地利用了Numpy的力量来避免这些差异。

顺便说一句,这里是一个有点快的fast2版本:

def calcDistanceMatrixFastEuclidean3(nDimPoints): 
    nDimPoints = array(nDimPoints) 
    n,m = nDimPoints.shape 
    data = nDimPoints[:,0] 
    delta = (data - data[:,newaxis])**2 
    for d in xrange(1,m): 
    data = nDimPoints[:,d] 
    delta += (data - data[:,newaxis])**2 
    return sqrt(delta) 

不同的是,我们不再产生增量的零矩阵。

+0

非常有帮助,谢谢。从我+1。 – doug 2010-08-19 07:31:20

1

dis的乐趣:

dis.dis(calcDistanceMatrixFastEuclidean)

2   0 LOAD_GLOBAL    0 (len) 
       3 LOAD_FAST    0 (points) 
       6 CALL_FUNCTION   1 
       9 STORE_FAST    1 (numPoints) 

    3   12 LOAD_GLOBAL    1 (sqrt) 
      15 LOAD_GLOBAL    2 (sum) 
      18 LOAD_GLOBAL    3 (repmat) 
      21 LOAD_FAST    0 (points) 
      24 LOAD_FAST    1 (numPoints) 
      27 LOAD_CONST    1 (1) 
      30 CALL_FUNCTION   3 

    4   33 LOAD_GLOBAL    4 (repeat) 
      36 LOAD_FAST    0 (points) 
      39 LOAD_FAST    1 (numPoints) 
      42 LOAD_CONST    2 ('axis') 
      45 LOAD_CONST    3 (0) 
      48 CALL_FUNCTION   258 
      51 BINARY_SUBTRACT 
      52 LOAD_CONST    4 (2) 
      55 BINARY_POWER 
      56 LOAD_CONST    2 ('axis') 
      59 LOAD_CONST    1 (1) 
      62 CALL_FUNCTION   257 
      65 CALL_FUNCTION   1 
      68 STORE_FAST    2 (distMat) 

    5   71 LOAD_FAST    2 (distMat) 
      74 LOAD_ATTR    5 (reshape) 
      77 LOAD_FAST    1 (numPoints) 
      80 LOAD_FAST    1 (numPoints) 
      83 BUILD_TUPLE    2 
      86 CALL_FUNCTION   1 
      89 RETURN_VALUE 

dis.dis(calcDistanceMatrixFastEuclidean2)

2   0 LOAD_GLOBAL    0 (array) 
       3 LOAD_FAST    0 (nDimPoints) 
       6 CALL_FUNCTION   1 
       9 STORE_FAST    0 (nDimPoints) 

    3   12 LOAD_FAST    0 (nDimPoints) 
      15 LOAD_ATTR    1 (shape) 
      18 UNPACK_SEQUENCE   2 
      21 STORE_FAST    1 (n) 
      24 STORE_FAST    2 (m) 

    4   27 LOAD_GLOBAL    2 (zeros) 
      30 LOAD_FAST    1 (n) 
      33 LOAD_FAST    1 (n) 
      36 BUILD_TUPLE    2 
      39 LOAD_CONST    1 ('d') 
      42 CALL_FUNCTION   2 
      45 STORE_FAST    3 (delta) 

    5   48 SETUP_LOOP    76 (to 127) 
      51 LOAD_GLOBAL    3 (xrange) 
      54 LOAD_FAST    2 (m) 
      57 CALL_FUNCTION   1 
      60 GET_ITER 
     >> 61 FOR_ITER    62 (to 126) 
      64 STORE_FAST    4 (d) 

    6   67 LOAD_FAST    0 (nDimPoints) 
      70 LOAD_CONST    0 (None) 
      73 LOAD_CONST    0 (None) 
      76 BUILD_SLICE    2 
      79 LOAD_FAST    4 (d) 
      82 BUILD_TUPLE    2 
      85 BINARY_SUBSCR 
      86 STORE_FAST    5 (data) 

    7   89 LOAD_FAST    3 (delta) 
      92 LOAD_FAST    5 (data) 
      95 LOAD_FAST    5 (data) 
      98 LOAD_CONST    0 (None) 
      101 LOAD_CONST    0 (None) 
      104 BUILD_SLICE    2 
      107 LOAD_GLOBAL    4 (newaxis) 
      110 BUILD_TUPLE    2 
      113 BINARY_SUBSCR 
      114 BINARY_SUBTRACT 
      115 LOAD_CONST    2 (2) 
      118 BINARY_POWER 
      119 INPLACE_ADD 
      120 STORE_FAST    3 (delta) 
      123 JUMP_ABSOLUTE   61 
     >> 126 POP_BLOCK 

    8  >> 127 LOAD_GLOBAL    5 (sqrt) 
      130 LOAD_FAST    3 (delta) 
      133 CALL_FUNCTION   1 
      136 RETURN_VALUE 

我不是dis的专家,但好像你不得不看更多在f第一次打电话告诉他们为什么需要一段时间。还有一个使用Python的性能分析工具,cProfile

+1

如果您使用[cProfile](http://docs.python.org/library/profile.html#instant-user-s-manual),我建议使用[RunSnakeRun](http:// www。 vrplumber.com/programming/runsnakerun/)查看结果。 – detly 2010-08-19 04:25:17

+0

我注意到,Python优化的技巧似乎通常是让Python解释器尽可能少地执行Python指令。 – Omnifarious 2011-02-26 03:32:17