2014-12-04 51 views
2

我正在学习AVX内在用法,问题是如何优化下面的代码。我将它移植到内部工作的方式,但我感觉它变得更容易和更高效。固有的代码优化提示

C++伪代码版本

float min_value = FLOAT_MAX; 
float result_p = 0; 
for loop 
{ 
    float u = .... 

    if(u > 0.0f || u < 1.0f) 
    continue; 

    float p = ... 
    float t = .... 

    if(t < min_value) 
    { 
    min_value = t; 
    result_p = p; 
    } 
} 

我这个优化,用下面的代码:

int resultMask = 0 
float min_value = FLOAT_MAX; 
float result_p = 0; 
for loop 
{ 
    __m256 u = .... 

    if(u.m256_f32[0] < 0.0f || u.m256_f32[0] > 1.0f) resultMask &= 0xFE; 
    if(u.m256_f32[1] < 0.0f || u.m256_f32[1] > 1.0f) resultMask &= 0xFD; 
    if(u.m256_f32[2] < 0.0f || u.m256_f32[2] > 1.0f) resultMask &= 0xFB; 
    if(u.m256_f32[3] < 0.0f || u.m256_f32[3] > 1.0f) resultMask &= 0xF7; 
    if(u.m256_f32[4] < 0.0f || u.m256_f32[4] > 1.0f) resultMask &= 0xEF; 
    if(u.m256_f32[5] < 0.0f || u.m256_f32[5] > 1.0f) resultMask &= 0xDF; 
    if(u.m256_f32[6] < 0.0f || u.m256_f32[6] > 1.0f) resultMask &= 0xBF; 
    if(u.m256_f32[7] < 0.0f || u.m256_f32[7] > 1.0f) resultMask &= 0x7F; 

    if(resultMask == 0) 
    continue; 

    __m256 p = ... 
    __m256 t = .... 

    if(resultMask & 0x01) if(t.m256_f32[0] < min_value) {min_value = t.m256_f32[0]; result_p = p.m256_f32[0];} 
    if(resultMask & 0x02) if(t.m256_f32[1] < min_value) {min_value = t.m256_f32[1]; result_p = p.m256_f32[1];} 
    if(resultMask & 0x04) if(t.m256_f32[2] < min_value) {min_value = t.m256_f32[2]; result_p = p.m256_f32[2];} 
    if(resultMask & 0x08) if(t.m256_f32[3] < min_value) {min_value = t.m256_f32[3]; result_p = p.m256_f32[3];} 
    if(resultMask & 0x10) if(t.m256_f32[4] < min_value) {min_value = t.m256_f32[4]; result_p = p.m256_f32[4];} 
    if(resultMask & 0x20) if(t.m256_f32[5] < min_value) {min_value = t.m256_f32[5]; result_p = p.m256_f32[5];} 
    if(resultMask & 0x40) if(t.m256_f32[6] < min_value) {min_value = t.m256_f32[6]; result_p = p.m256_f32[6];} 
    if(resultMask & 0x80) if(t.m256_f32[7] < min_value) {min_value = t.m256_f32[7]; result_p = p.m256_f32[7];} 
} 

所有的 “如果” 是丑陋的,但我不能找到另一种解决方案。有人知道如何改变这种情况?我无法真正相信这是可以做到的最好的。

Thx

+0

欢迎!这不是适合堆栈溢出的实际问题。对于代码评论使用适当的地方[codereview.stackexchange](http://codereview.stackexchange.com/) – aggsol 2014-12-04 12:49:26

+0

不,这不是最好的。尝试使用'_mm256_cmp_ps'在单条指令中执行8次比较。 (您需要一个用于下限,另一个用于上限) – 2014-12-04 14:44:15

回答

4

首先要尝试的是自动矢量化。要做到这一点,您需要启用自动矢量化和AVX与GCC gcc -O3 -mavx。但如果你真的想与内在函数来做到这一点,你可以尝试这样的事:

__m256 min_value8 = _mm256_set1_ps(FLT_MAX); 
__m256 result_p8 = _mm256_setzero_ps(); 
__m256 one  = _mm256_set1_ps(1.0); 

for(int i=0; i<n; i+=8) { 
    //__m256 u8 = _mm256_loadu_ps(&u[i]); 
    __m256 u8 = ... 
    __m256 t1 = _mm256_cmp_ps(u8, _mm256_setzero_ps(), 2); // u <= 0 
    __m256 t2 = _mm256_cmp_ps(one, u8, 2);     // 1 <= u 
    __m256 t3 = _mm256_or_ps(t1,t2); 
    if(_mm256_testz_ps(t3,t3)) continue; 
    //at least one entry in u8 has u<=0 or u>=1 
    __m256 p8 = ... 
    __m256 t8 = ... 

    __m256 mask = _mm256_cmp_ps(t8, min_value8, 1);  // t < min_value 
    //min_value8 = _mm256_or_ps(_mm256_and_ps(mask,t8), _mm256_andnot_ps(mask,min_value8)); 
    //result_p8 = _mm256_or_ps(_mm256_and_ps(mask,p8), _mm256_andnot_ps(mask,result_p8)); 
    min_value8 = _mm256_blendv_ps(min_value8, t8, mask); 
    result_p8 = _mm256_blendv_ps(result_p8, p8, mask); 
} 
float tmp1[8], tmp2[8]; 
_mm256_storeu_ps(tmp1, min_value8); 
_mm256_storeu_ps(tmp2, result_p8); 
float min_value = tmp1[0]; 
float result_p = tmp2[0]; 
for(int i=1; i<8; i++) { 
    if(tmp1[i]<min_value) { 
     min_value = tmp1[i]; 
     result_p = tmp2[i]; 
    } 
} 

这个假设的迭代是独立的,即是p8t8不依赖于min_value8

编辑:

我被下面的代码

__m256 mask = _mm256_cmp_ps(t8, min_value8, 1); 
min_value8 = _mm256_or_ps(_mm256_and_ps(mask,t8), _mm256_andnot_ps(mask,min_value8)); 
result_p8 = _mm256_or_ps(_mm256_and_ps(mask,p8), _mm256_andnot_ps(mask,result_p8)); 

一条线可以简化为困扰:

min_value8 = _mm256_min_ps(t8, min_value8); //probably faster 

但是,使用min在某种意义上重新计算的面具。更好的解决方案是与面膜混合

_m256 mask = _mm256_cmp_ps(t8, min_value8, 1); 
min_value8 = _mm256_blendv_ps(min_value8, t8, mask); 
result_p8 = _mm256_blendv_ps(result_p8, p8, mask); 
+1

自动矢量化不可能在循环体中使用'continue;'语句。也许如果最后的任务是有条件的。 – 2014-12-04 14:41:11

+0

@BenVoigt,在这种情况下,我很高兴我给出了一个解决方案(未经测试)使用内在:-) – 2014-12-04 14:43:35

+0

利用领域特定知识的奇迹...... – Mysticial 2014-12-04 19:26:32