这是一个算法,如果中间维度足够大并且条目均匀分布,则会节省一点。它利用了这样一个事实,即最小的总和通常是来自两个小项。
import numpy as np
def min_plus_product(A,B):
B = np.transpose(B)
Y = np.zeros((len(B),len(A)))
for i in range(len(B)):
Y[i] = (A + B[i]).min(1)
return np.transpose(Y)
def min_plus_product_opt(A,B, chop=None):
if chop is None:
# not sure this is optimal
chop = int(np.ceil(np.sqrt(A.shape[1])))
B = np.transpose(B)
Amin = A.min(1)
Y = np.zeros((len(B),len(A)))
for i in range(len(B)):
o = np.argsort(B[i])
Y[i] = (A[:, o[:chop]] + B[i, o[:chop]]).min(1)
if chop < len(o):
idx = np.where(Amin + B[i, o[chop]] < Y[i])[0]
for j in range(chop, len(o), chop):
if len(idx) == 0:
break
x, y = np.ix_(idx, o[j : j + chop])
slmin = (A[x, y] + B[i, o[j : j + chop]]).min(1)
slmin = np.minimum(Y[i, idx], slmin)
Y[i, idx] = slmin
nidx = np.where(Amin[idx] + B[i, o[j + chop]] < Y[i, idx])[0]
idx = idx[nidx]
return np.transpose(Y)
A = np.random.random(size=(1000,1000))
B = np.random.random(size=(1000,2000))
print(np.allclose(min_plus_product(A,B), min_plus_product_opt(A,B)))
import time
t = time.time();min_plus_product(A,B);print('naive {}sec'.format(time.time()-t))
t = time.time();min_plus_product_opt(A,B);print('opt {}sec'.format(time.time()-t))
示例输出:
True
naive 7.794037580490112sec
opt 1.65810227394104sec
您的输入有多大? 'min_plus_product'的速度与'dot'的普通矩阵乘法的速度相比如何? – user2357112
*“...对于大矩阵很慢”*与@ user2357112同样的问题:矩阵的典型大小是多少? –
我测试了A10000和B10020000,得到了135s的矿,1.95s的np.dot。有时我的矩阵可能会比这更大。 – Rael