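I wrote the code below to understand multiprocessing (MP) and the speed gain it can give compared with a non-MP version. The two functions are nearly identical apart from the highlighted lines (sorry, I don't know a better way to highlight regions of code).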
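The code tries to identify the indices of redundant entries in a list of arrays (1-D here). The lists of ids returned by the two functions are identical, but my question is about the difference in timing. As you can see, I have timed a) the map function, b) the list extension and c) the whole while loop in both cases. Compared with the non-MP version, MP gives a better speedup in the map region, but its `redun_ids.extend(...)` is slower, and this drags down the overall speed gain of the MP version.
Is there any way to rewrite the `redun_ids.extend(...)` part of the MP version so that it takes the same time as in the non-MP version?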
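One thing worth checking when reading these numbers: `Pool.map_async` returns an `AsyncResult` immediately, without waiting for the workers. So the "map time" printed in `mp_remove_redundancy` covers only submitting the tasks; the actual waiting happens in `clslist.get()`, which neither of the two inner timers covers. A minimal sketch (Python 3, not the Python 2 of the question) that times dispatch and the `.get()` wait separately. `slow_square` and `timed_async_map` are hypothetical names, and `multiprocessing.pool.ThreadPool` is used only so the sketch runs anywhere; it exposes the same `map_async`/`get` methods as `multiprocessing.Pool`.

```python
import time
from multiprocessing.pool import ThreadPool


def slow_square(x):
    # Hypothetical stand-in for matdist: sleep so the work takes measurable time.
    time.sleep(0.05)
    return x * x


def timed_async_map(pool, func, items, chunksize=1):
    """Return (results, dispatch_seconds, wait_seconds) for pool.map_async."""
    start = time.perf_counter()
    async_result = pool.map_async(func, items, chunksize)
    # map_async has already returned here; the workers may still be running.
    dispatch = time.perf_counter() - start

    start = time.perf_counter()
    results = async_result.get()  # blocks until every result is back
    wait = time.perf_counter() - start
    return results, dispatch, wait


if __name__ == '__main__':
    with ThreadPool(processes=4) as pool:
        results, dispatch, wait = timed_async_map(pool, slow_square, range(8))
    print('results:', results)
    print('dispatch: %.4f s, wait in get(): %.4f s' % (dispatch, wait))
```

Passing a `multiprocessing.Pool` instead works the same way, provided the worker is a module-level function (as `matdist` is) so that it can be pickled.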
#!/usr/bin/python
import multiprocessing as mproc
import sys
import numpy as np
import random
import time

def matdist(mats):
    mat1 = mats[0]
    mat2 = mats[1]
    return np.allclose(mat1, mat2, rtol=1e-08, atol=1e-12)
def mp_remove_redundancy(larrays):
    """
    remove_redundancy : identify arrays that are redundant in the
    input list of arrays
    """
    llen = len(larrays)
    redun_ids = list()
    templist = list()
    i = 0
    pool = mproc.Pool(processes=10)  # MP-only: create the worker pool
    st1 = time.time()
    while 1:
        currarray = larrays[i]
        if i not in redun_ids:
            templist.append(currarray)
        #replication to create list of arrays
        templist = templist*(llen-i-1)
        # MP-only: split the work into ~10 chunks of at least 1 item
        chunksize = len(templist)//10
        if chunksize == 0:
            chunksize = 1
        #clslist is a result object here
        st = time.time()
        # MP-only: asynchronous map over the pool
        clslist = pool.map_async(matdist, zip(larrays[i+1:],
                                              templist), chunksize)
        print 'map time:', time.time()-st
        # MP-only: wait for the workers and collect their results
        outlist = clslist.get()[:]
        #j+1+i gives the actual id num w.r.t to whole list
        st = time.time()
        redun_ids.extend([j+1+i for j, x in
                          enumerate(outlist) if x])
        print 'Redun ids extend time:', time.time()-st
        i = i + 1
        del templist[:]
        del outlist[:]
        if i == (llen - 1):
            break
    print 'Time elapsed in MP:', time.time()-st1
    pool.close()
    pool.join()
    del clslist
    del templist
    return redun_ids[:]
#######################################################
def remove_redundancy(larrays):
    llen = len(larrays)
    redun_ids = list()
    clslist = list()
    templist = list()
    i = 0
    st1 = time.time()
    while 1:
        currarray = larrays[i]
        if i not in redun_ids:
            templist.append(currarray)
        templist = templist*(llen-i-1)
        st = time.time()
        clslist = map(matdist, zip(larrays[i+1:],
                                   templist))
        print 'map time:', time.time()-st
        #j+1+i gives the actual id num w.r.t to whole list
        st = time.time()
        redun_ids.extend([j+1+i for j, x in
                          enumerate(clslist) if x])
        print 'Redun ids extend time:', time.time()-st
        i = i + 1
        #clear temp vars
        del clslist[:]
        del templist[:]
        if i == (llen - 1):
            break
    print 'Tot non MP time:', time.time()-st1
    del clslist
    del templist
    return redun_ids[:]
###################################################################
if __name__=='__main__':
    if len(sys.argv) != 2:
        sys.exit('usage: %s <number of entries>' % sys.argv[0])
    llen = int(sys.argv[1])
    #generate random numbers between 1 and 10
    mylist=[np.array([round(random.random()*9+1)]) for i in range(llen)]
    print 'The input list'
    print 'no MP'
    rrlist = remove_redundancy(mylist)
    print 'MP'
    rrmplist = mp_remove_redundancy(mylist)
    print 'Two lists match : {0}'.format(rrlist==rrmplist)
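For comparison, a minimal sequential sketch of the same redundancy search (Python 3; `find_redundant` is a hypothetical name, not part of the code above). It keeps the discovery order of the functions above and makes the same `np.allclose(larrays[j], larrays[i], ...)` comparison, but tests membership against a `set`, since `i not in redun_ids` on a list scans the whole list on every iteration, and it compares pairs directly instead of building the replicated `templist`.

```python
import numpy as np


def find_redundant(arrays, rtol=1e-08, atol=1e-12):
    """Return indices j of arrays that are allclose to an earlier, non-redundant array."""
    redun_ids = []     # discovery order, as in remove_redundancy
    redun_set = set()  # O(1) membership test instead of scanning redun_ids
    n = len(arrays)
    for i in range(n - 1):
        if i in redun_set:
            continue  # a redundant array is never used as a reference
        for j in range(i + 1, n):
            # Same argument order as matdist: (later array, reference array)
            if np.allclose(arrays[j], arrays[i], rtol=rtol, atol=atol):
                redun_ids.append(j)
                redun_set.add(j)
    return redun_ids


if __name__ == '__main__':
    data = [np.array([1.0]), np.array([2.0]), np.array([1.0]),
            np.array([3.0]), np.array([2.0])]
    print(find_redundant(data))  # [2, 4]
```

Entries 2 and 4 repeat entries 0 and 1, so they are the ones reported; entry 2 is then skipped as a reference, exactly as in the loop above.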
I guess this explains the doubt I expressed in my own answer. :) – 2012-08-17 08:59:06