2016-03-30 30 views
3

有没有办法避免循环使代码更快?有没有办法避免循环使代码更快?

“var”是期望的结果。

“AA”和“BB”是其值已知的向量。

这段代码的四个主线是基于逻辑:000110110<=1>

for i=3:1:2000 

    for j=1:50011 
     if AA(i) == j && BB(i)<=BB(i-1) && BB(i-1)<=BB(i-2) 
      var(i)=j; 
     else 
      if AA(i) == j && BB(i)<=BB(i-1) && BB(i-1)>BB(i-2) 
       var(i)=j+50011; 
      else 
       if AA(i) == j && BB(i)>BB(i-1) && BB(i-1)<=BB(i-2) 
        var(i)=j+2*50011; 
       else 
        if AA(i) == j && BB(i)>BB(i-1) && BB(i-1)>BB(i-2) 
         var(i)=j+3*50011; 
        end 
       end 
      end 
     end 
    end 

end 

回答

3

我设法将您的代码矢量化为一行代码!从几秒钟下降到毫秒下:

out = [0 0 AA(3:end) + ([1 2] * (diff(BB(hankel(1:3, 3:numel(BB)))) > 0)).*50011]; 

要了解如何到达那里,让我们逐步提高了原代码。


1)

首先我们先从双循环您有:

tic 
var0 = zeros(size(AA)); 
for i=3:numel(AA) 
    for j=1:N 
     if AA(i) == j && BB(i)<=BB(i-1) && BB(i-1)<=BB(i-2) 
      var0(i)=j; 
     else 
      if AA(i) == j && BB(i)<=BB(i-1) && BB(i-1)>BB(i-2) 
       var0(i)=j+50011; 
      else 
       if AA(i) == j && BB(i)>BB(i-1) && BB(i-1)<=BB(i-2) 
        var0(i)=j+2*50011; 
       else 
        if AA(i) == j && BB(i)>BB(i-1) && BB(i-1)>BB(i-2) 
         var0(i)=j+3*50011; 
        end 
       end 
      end 
     end 
    end 
end 
toc 

2)

由于@SpamBot指出,嵌套if/else语句可以通过链接来简化它们。你也要多次评估相同的测试AA(i)==j。如果测试是错误的,整个for循环会被跳过。所以我们可以消除第二个for循环,并直接使用j=AA(i)

下面是新的代码:

tic 
var1 = zeros(size(AA)); 
for i=3:numel(AA) 
    j = AA(i); 
    if BB(i)<=BB(i-1) && BB(i-1)<=BB(i-2) 
     var1(i) = j; 
    elseif BB(i)<=BB(i-1) && BB(i-1)>BB(i-2) 
     var1(i) = j + 50011; 
    elseif BB(i)>BB(i-1) && BB(i-1)<=BB(i-2) 
     var1(i) = j + 2*50011; 
    elseif BB(i)>BB(i-1) && BB(i-1)>BB(i-2) 
     var1(i) = j + 3*50011; 
    end 
end 
toc 

这是一个巨大的进步,并且代码将在原来的一小部分时间运行。不过,我们可以做的更好......

3)

正如你在你的问题中提到的的if/else条件对应于模式00, 01, 10, 11其中0/1或假/真正执行的是二进制的结果x> y测试相邻的数字。

使用这个想法,我们可以得到下面的代码:

tic 
var2 = zeros(size(AA)); 
for i=3:numel(AA) 
    val = (BB(i) > BB(i-1)) * 10 + (BB(i-1) > BB(i-2)); 
    switch (val) 
     case 00 
      k = 0; 
     case 01 
      k = 50011; 
     case 10 
      k = 2*50011; 
     case 11 
      k = 3*50011; 
    end 
    var2(i) = AA(i) + k; 

end 
toc 

4)

让我们来替换switch语句用一个表查找操作。这给了我们这个新版本:

tic 
v = [0 1 2 3] * 50011; % 00 01 10 11 
var3 = zeros(size(AA)); 
for i=3:numel(AA) 
    var3(i) = AA(i) + v((BB(i) > BB(i-1))*2 + (BB(i-1) > BB(i-2)) + 1); 
end 
toc 

5)

在这最后的版本中,我们可以完全由提的是,每次迭代访问片BB(i-2:i)在滑动窗口的方式摆脱循环。我们可以在BB之上整齐地排列use the hankel function to create a sliding window(每个都作为列返回)。

接下来我们使用diff来执行向量化比较,然后将两个测试的结果0/1映射为[0 1 2 3]*50011值。最后,我们适当地添加矢量AA

这赋予我们最后一个内胆,完全矢量化:

tic 
var4 = [0, 0, AA(3:end) + ([1 2] * (diff(BB(hankel(1:3, 3:numel(BB)))) > 0)).*50011]; 
toc 

比较

为了验证上述解决方案,我用下面的随机向量作为测试数据:

N = 50011; 
AA = randi(N, [1 2000]); 
BB = randi(N, [1 2000]); 

assert(isequal(var0,var1,var2,var3,var4)) 

我得到以下时间匹配解决方案的顺序:

>> myscript % tested in MATLAB R2014a 
Elapsed time is 1.seconds. 
Elapsed time is 0.000111 seconds. 
Elapsed time is 0.000099 seconds. 
Elapsed time is 0.000089 seconds. 
Elapsed time is 0.000417 seconds. 

>> myscript % tested in MATLAB R2015b (with the new execution engine) 
Elapsed time is 2.816541 seconds. 
Elapsed time is 0.000233 seconds. 
Elapsed time is 0.000158 seconds. 
Elapsed time is 0.000157 seconds. 
Elapsed time is 0.000339 seconds. 

希望这篇文章不是太长,我只是想展示如何通过增量更改来解决这类问题。

现在你挑在该解决方案,您最喜欢:)

+0

非常感谢您的支持!但是,如果我理解你的比较,(除了第一个版本),最终vesion需要比其他三个更多的时间! – bzak

+0

我们在这里以毫秒级的顺序讨论,没有任何显着的区别......我建议你选择你最喜欢的解决方案(上面的#2完全没错)。我承认,在#3至#5中,重点在于获得比任何戏剧性能提升都更短且完全矢量化的解决方案。 – Amro

+0

您在MATLAB标签中看到,我们经常将这类问题作为挑战来编写尽可能短的矢量化代码,但这并不总是意味着最可读或最快的:) – Amro

2

我觉得最重要的改进是评估AA(i)==j循环j时只有一次。

此外,尽管只有最后一次覆盖是相关的,但您可能会经常覆盖var(i)。考虑只采取一个j = == AA(i)并且只为此做if-else。基本上,您正在AA中搜索j,请改为使用find(AA == j, 1)

作为一个方面说明,如果/ end块可以直接进入其他父块,那么就没有必要了。

1

这里的主要是利用logical indexing一个量化的方法 -

%// Get size parameter 
N = numel(AA)-2; 
limit = 50011; 

%// Get differentiation across BB 
dB = diff(BB); 

%// Construct an ID array with different valus for various conditions 
id = ones(N,1); 
id((dB(2:end) <= 0) & (dB(1:end-1) > 0)) = 2; 
id((dB(2:end) > 0) & (dB(1:end-1) <= 0)) = 3; 
id((dB(2:end) > 0) & (dB(1:end-1) > 0)) = 4; 

%// Get scaled values as used under various IF statements 
vals = ((id-1)*50011) + AA(3:end); 

%// Get a valid mask that would be used to set values from vals into output 
valid_mask = AA(3:end) <= limit; 

%// Setup output array and selectively set values from vals using valid_mask 
var_out = zeros(1,N); 
var_out(valid_mask) = vals(valid_mask); 

请注意,原始输出将有前两个元素一如既往零。提出的解决方案的输出会跳过前两个元素以避免冗余。如果需要与旧范式保持一致,请在开头填写两个零 -

final_out = [0 0 var_out]; 
相关问题