2016-08-23 65 views
3

我试图向量化某个加权总和,但无法弄清楚如何去做。我在下面创建了一个简单的最小工作示例。我猜这个解决方案涉及到bsxfun或重塑和克罗内克产品,但我仍然没有设法使它工作。三重加权总和

rng(1); 
N = 200; 
T1 = 5; 
T2 = 7; 
T3 = 10; 


A = rand(N,T1,T2,T3); 
w1 = rand(T1,1); 
w2 = rand(T2,1); 
w3 = rand(T3,1); 

B = zeros(N,1); 

for i = 1:N 
for j1=1:T1 
    for j2=1:T2 
    for j3=1:T3 
    B(i) = B(i) + w1(j1) * w2(j2) * w3(j3) * A(i,j1,j2,j3); 
    end 
    end 
end 
end 

A = B; 

对于二维情况,有一个智能答案here

+0

你需要推广?因为如果是这样的话,我会把你的N,T1,T2,T3换成一个数组。 –

+0

我其实只是想要三维的情况。但泛化可能对其他人有用:) – phdstudent

+0

以下概括:) –

回答

5

您可以使用额外的乘法修改前一个答案的w1 * w2'网格,然后再乘以w3。然后,您可以再次使用矩阵乘法与A的“拼合”版本相乘。

W = reshape(w1 * w2.', [], 1) * w3.'; 
B = reshape(A, size(A, 1), []) * W(:); 

你可以换权的创建到它自己的功能,使这个推广到N权重。由于这使用递归,因此N仅限于当前的递归限制(默认为500)。

function W = createWeights(W, varargin) 
    if numel(varargin) > 0 
     W = createWeights(W(:) * varargin{1}(:).', varargin{2:end}); 
    end 
end 

而且随着使用它:

W = createWeights(w1, w2, w3); 
B = reshape(A, size(A, 1), []) * W(:); 

更新

使用的@ CKT的很好的建议,使用kron一部分,我们可以修改createWeights只是一点点。

function W = createWeights(W, varargin) 
    if numel(varargin) > 0 
     W = createWeights(kron(varargin{1}, W), varargin{2:end}); 
    end 
end 
+0

@Suever基准顶部! :) –

1

这是一个道理:

ww1 = repmat (permute (w1, [4, 1, 2, 3]), [N, 1, T2, T3]); 
ww2 = repmat (permute (w2, [3, 4, 1, 2]), [N, T1, 1, T3]); 
ww3 = repmat (permute (w3, [2, 3, 4, 1]), [N, T1, T2, 1 ]); 

B = ww1 .* ww2 .* ww3 .* A; 
B = sum (B(:,:), 2) 

您可以通过在首位适当的尺寸创建w1w2w3避免permute。此外,您可以使用bsxfun而不是repmat来获得额外的性能,我只是在此处显示逻辑,而repmat更容易遵循。

编辑:广义版本任意输入尺寸:

Dims = {N, T1, T2, T3}; % add T4, T5, T6, etc as appropriate 
Params = cell (1, length (Dims)); 

Params{1} = rand (Dims{:}); 
for n = 2 : length (Dims) 
    DimSubscripts = ones (1, length (Dims)); DimSubscripts(n) = Dims{n}; 
    RepSubscripts = [Dims{:}]; RepSubscripts(n) = 1; 
    Params{n} = repmat (rand (DimSubscripts), RepSubscripts); 
end 

B = times (Params{:}); 
B = sum (B(:,:), 2) 
1

同样,你不能概括这一点,以及对ND,除非你做了一些功能来构造克罗内克产品载体,但如何

A = reshape(A, N, []) * kron(w3, kron(w2, w1)); 
1

如果我们想反正有功能的途径,并偏袒优雅/简洁的性能,然后再考虑这一点:

function B = weightReduce(A, varargin) 

    B = A; 
    for i = length(varargin):-1:1 
     N = length(varargin{i}); 
     B = reshape(B, [], N) * varargin{i}; 
    end 

end 

这是性能的比较,我看到:

tic; 
for i = 1:10000 
    W = createWeights(w1,w2,w3); 
    B = reshape(A, size(A,1), [])*W(:); 
end 
toc 
Elapsed time is 0.920821 seconds. 
tic; 
for i = 1:10000 
    B2 = weightReduce(A, w1, w2, w3); 
end 
toc 
Elapsed time is 0.484470 seconds.