2017-11-17 156 views
0

所以我有两个矩阵,A和B,我想计算这里给出的最小加乘积:Min-plus matrix multiplication。为此,我实施了以下操作:如何使python中的Min-plus矩阵乘法更快?

def min_plus_product(A,B): 
    B = np.transpose(B) 
    Y = np.zeros((len(B),len(A))) 
    for i in range(len(B)): 
     Y[i] = (A + B[i]).min(1) 
    return np.transpose(Y) 

这工作正常,但对大矩阵很慢,有没有办法使其更快?我听说在C或使用GPU实现可能是不错的选择。

+1

您的输入有多大? 'min_plus_product'的速度与'dot'的普通矩阵乘法的速度相比如何? – user2357112

+0

*“...对于大矩阵很慢”*与@ user2357112同样的问题:矩阵的典型大小是多少? –

+0

我测试了A10000和B10020000,得到了135s的矿,1.95s的np.dot。有时我的矩阵可能会比这更大。 – Rael

回答

2

这是一个算法,如果中间维度足够大并且条目均匀分布,则会节省一点。它利用了这样一个事实,即最小的总和通常是来自两个小项。

import numpy as np 

def min_plus_product(A,B): 
    B = np.transpose(B) 
    Y = np.zeros((len(B),len(A))) 
    for i in range(len(B)): 
     Y[i] = (A + B[i]).min(1) 
    return np.transpose(Y) 


def min_plus_product_opt(A,B, chop=None): 
    if chop is None: 
     # not sure this is optimal 
     chop = int(np.ceil(np.sqrt(A.shape[1]))) 
    B = np.transpose(B) 
    Amin = A.min(1) 
    Y = np.zeros((len(B),len(A))) 
    for i in range(len(B)): 
     o = np.argsort(B[i]) 
     Y[i] = (A[:, o[:chop]] + B[i, o[:chop]]).min(1) 
     if chop < len(o): 
      idx = np.where(Amin + B[i, o[chop]] < Y[i])[0] 
      for j in range(chop, len(o), chop): 
       if len(idx) == 0: 
        break 
       x, y = np.ix_(idx, o[j : j + chop]) 
       slmin = (A[x, y] + B[i, o[j : j + chop]]).min(1) 
       slmin = np.minimum(Y[i, idx], slmin) 
       Y[i, idx] = slmin 
       nidx = np.where(Amin[idx] + B[i, o[j + chop]] < Y[i, idx])[0] 
       idx = idx[nidx] 
    return np.transpose(Y) 

A = np.random.random(size=(1000,1000)) 
B = np.random.random(size=(1000,2000)) 

print(np.allclose(min_plus_product(A,B), min_plus_product_opt(A,B))) 

import time 
t = time.time();min_plus_product(A,B);print('naive {}sec'.format(time.time()-t)) 
t = time.time();min_plus_product_opt(A,B);print('opt {}sec'.format(time.time()-t)) 

示例输出:

True 
naive 7.794037580490112sec 
opt 1.65810227394104sec 
+0

@Rael感谢您的接受。如果你还在考虑去C语言(我推荐Cython,根据我的经验,它可以轻松地与numpy进行接口,并且比它的版本号可能提供的更成熟),在这种情况下,代码可以大大简化。基本上你不需要大块,也不需要屏蔽和重新索引,而只需循环“a”行,并一行一行地短路。 –

0

一种可能的简单途径是使用numba。上1000×1000甲

from numba import autojit 
import numpy as np 
@autojit(nopython=True) 
def min_plus_product(A,B): 
    n = A.shape[0] 
    C = np.zeros((n,n)) 
    for i in range(n): 
     for j in range(n): 
      minimum = A[i,0]+B[0,j] 
      for k in range(1,n): 
       minimum = min(A[i,k]+B[k,j],minimum) 
      C[i,j] = minimum 
    return C 

计时,B矩阵是:

1个循环,最好的3:每次循环4.28 S为原代码

1个循环,最好的3:每次循环2.32小号对于numba代码