2015-03-13 184 views
3

我想在单个内核上快速地乘上矩阵。我浏览了网页并找到了几个算法,发现Strassen的算法是唯一的算法,这实际上是由人们实施的。我已经看过几个例子,并参考了下面的解决方案。我做了一个简单的基准,它会生成两个随机填充的500x500矩阵。斯特拉森算法耗时18秒,高中算法在0.4秒内完成。其他人在实现算法后非常有希望,那么我的错在哪里,我该如何让它更快?实现Strassen算法

// return C = A * B 
private Matrix strassenTimes(Matrix B, int LEAFSIZE) { 
    Matrix A = this; 
    if (B.M != A.M || B.N != A.N) throw new RuntimeException("Illegal matrix dimensions."); 

    if (N <= LEAFSIZE || M <= LEAFSIZE) { 
     return A.times(B); 
    } 

    // make new sub-matrices 
    int newAcols = (A.N + 1)/2; 
    int newArows = (A.M + 1)/2; 
    Matrix a11 = new Matrix(newArows, newAcols); 
    Matrix a12 = new Matrix(newArows, newAcols); 
    Matrix a21 = new Matrix(newArows, newAcols); 
    Matrix a22 = new Matrix(newArows, newAcols); 

    int newBcols = (B.N + 1)/2; 
    int newBrows = (B.M + 1)/2; 
    Matrix b11 = new Matrix(newBrows, newBcols); 
    Matrix b12 = new Matrix(newBrows, newBcols); 
    Matrix b21 = new Matrix(newBrows, newBcols); 
    Matrix b22 = new Matrix(newBrows, newBcols); 


    for (int i = 1; i <= newArows; i++) { 
     for (int j = 1; j <= newAcols; j++) { 
      a11.setElement(i, j, A.saveGet(i, j)); // top left 
      a12.setElement(i, j, A.saveGet(i, j + newAcols)); // top right 
      a21.setElement(i, j, A.saveGet(i + newArows, j)); // bottom left 
      a22.setElement(i, j, A.saveGet(i + newArows, j + newAcols)); // bottom right 
     } 
    } 

    for (int i = 1; i <= newBrows; i++) { 
     for (int j = 1; j <= newBcols; j++) { 
      b11.setElement(i, j, B.saveGet(i, j)); // top left 
      b12.setElement(i, j, B.saveGet(i, j + newBcols)); // top right 
      b21.setElement(i, j, B.saveGet(i + newBrows, j)); // bottom left 
      b22.setElement(i, j, B.saveGet(i + newBrows, j + newBcols)); // bottom right 
     } 
    } 

    Matrix aResult; 
    Matrix bResult; 

    aResult = a11.add(a22); 
    bResult = b11.add(b22); 
    Matrix p1 = aResult.strassenTimes(bResult, LEAFSIZE); 

    aResult = a21.add(a22); 
    Matrix p2 = aResult.strassenTimes(b11, LEAFSIZE); 

    bResult = b12.minus(b22); // b12 - b22 
    Matrix p3 = a11.strassenTimes(bResult, LEAFSIZE); 

    bResult = b21.minus(b11); // b21 - b11 
    Matrix p4 = a22.strassenTimes(bResult, LEAFSIZE); 

    aResult = a11.add(a12); // a11 + a12 
    Matrix p5 = aResult.strassenTimes(b22, LEAFSIZE); 

    aResult = a21.minus(a11); // a21 - a11 
    bResult = b11.add(b12); // b11 + b12 
    Matrix p6 = aResult.strassenTimes(bResult, LEAFSIZE); 

    aResult = a12.minus(a22); // a12 - a22 
    bResult = b21.add(b22); // b21 + b22 
    Matrix p7 = aResult.strassenTimes(bResult, LEAFSIZE); 

    Matrix c12 = p3.add(p5); // c12 = p3 + p5 
    Matrix c21 = p2.add(p4); // c21 = p2 + p4 

    aResult = p1.add(p4); // p1 + p4 
    bResult = aResult.add(p7); // p1 + p4 + p7 
    Matrix c11 = bResult.minus(p5); 

    aResult = p1.add(p3); // p1 + p3 
    bResult = aResult.add(p6); // p1 + p3 + p6 
    Matrix c22 = bResult.minus(p2); 

    // Grouping the results obtained in a single matrix: 
    int rows = c11.nrRows(); 
    int cols = c11.nrColumns(); 

    Matrix C = new Matrix(A.M, B.N); 
    for (int i = 1; i <= A.M; i++) { 
     for (int j = 1; j <= B.N; j++) { 
      int el; 
      if (i <= rows) { 
       if (j <= cols) { 
        el = c11.get(i, j); 
       } else { 
        el = c12.get(i, j - cols); 
       } 
      } else { 
       if (j <= cols) { 
        el = c21.get(i - rows, j); 
       } else { 
        el = c22.get(i - rows, j - rows); 
       } 
      } 
      C.setElement(i, j, el); 
     } 
    } 
    return C; 
} 

小基准具有下面的代码:

int AM, AN, BM, BN; 
AM = 500; 
AN = BM = 500; 
BN = 500; 
Matrix a = new Matrix(AM, AN); 
Matrix b = new Matrix(BM, BN); 

Random random = new Random(); 

for (int i = 1; i <= AM; i++) { 
    for (int j = 1; j <= AN; j++) { 
     a.setElement(i, j, random.nextInt(20)); 
    } 
} 
for (int i = 1; i <= BM; i++) { 
    for (int j = 1; j <= BN; j++) { 
     b.setElement(i, j, random.nextInt(20)); 
    } 
} 

System.out.println("strassen: A x B"); 
long tijd = System.currentTimeMillis(); 
Matrix c = a.strassenTimes(b); 
System.out.println("time = " + (System.currentTimeMillis() - tijd)); 

System.out.println("normal: A x B"); 
tijd = System.currentTimeMillis(); 
Matrix d = a.times(b); 
System.out.println("time = " + (System.currentTimeMillis() - tijd)); 

System.out.println("nr of different elements = " + c.compare(d)); 

结果如下:

strassen: A x B 
time = 18372 
normal: A x B 
time = 308 
nr of different elements = 0 

我知道这是一个代码低,但我想,如果你很开心大家帮帮我;)

编辑1: 为了完整起见,我添加了上面代码使用的一些方法。

public int get(int r, int c) { 
    if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) { 
     throw new ArrayIndexOutOfBoundsException("matrix is of size (" + 
       nrRows() + ", " + nrColumns() + "), but tries to set element(" + r + ", " + c + ")"); 
    } 

    return content[r - 1][c - 1]; 
} 

private int saveGet(int r, int c) { 
    if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) { 
     return 0; 
    } 

    return content[r - 1][c - 1]; 
} 

public void setElement(int r, int c, int n) { 
    if (c > nrColumns() || r > nrRows() || c <= 0 || r <= 0) { 
     throw new ArrayIndexOutOfBoundsException("matrix is of size (" + 
       nrRows() + ", " + nrColumns() + "), but tries to set element(" + r + ", " + c + ")"); 
    } 
    content[r - 1][c - 1] = n; 
} 

// return C = A + B 
public Matrix add(Matrix B) { 
    Matrix A = this; 
    if (B.M != A.M || B.N != A.N) throw new RuntimeException("Illegal matrix dimensions."); 
    Matrix C = new Matrix(M, N); 
    for (int i = 0; i < M; i++) { 
     for (int j = 0; j < N; j++) { 
      C.content[i][j] = A.content[i][j] + B.content[i][j]; 
     } 
    } 
    return C; 
} 
+0

Matrix类在哪里?我在上面粘贴的代码中没有看到它...... – 2015-03-13 20:28:07

+0

所有那些新的Matrix()实例,在每一层递归中都不会很快,我认为。 – IVlad 2015-03-13 20:30:29

+4

您的Strassen实现从创建新的矩阵,复制矩阵元素和进行递归调用中有大量(理论上)不需要的开销。无论如何,一个高效的Strassen实现只比天真的算法快一点,并且只适用于足够大的矩阵。 – 2015-03-13 20:30:38

回答

2

我应该为Strassen的算法选择另一个叶子大小。所以我做了一个小实验。看起来叶片大小256最适合问题中包含的代码。下面用不同大小的叶子地块尺寸的随机矩阵1025 X 1025

leaf size

我比较Strassen's算法与叶大小256琐碎的算法,矩阵乘法每次用,就看这实际上是一种改进。事实证明这是一种改进,见下面的结果在不同大小的随机矩阵(以10为单位,每个大小重复50次)。 matrix size

下面的琐碎算法矩阵乘法代码:

// return C = A * B 
public Matrix times(Matrix B) { 
    Matrix A = this; 
    if (A.N != B.M) throw new RuntimeException("Illegal matrix dimensions."); 
    Matrix C = new Matrix(A.M, B.N); 
    for (int i = 0; i < C.M; i++) { 
     for (int j = 0; j < C.N; j++) { 
      for (int k = 0; k < A.N; k++) { 
       C.content[i][j] += (A.content[i][k] * B.content[k][j]); 
      } 
     } 
    } 
    return C; 
} 

它仍然认为可以做到在执行其他改进,但事实证明,叶大小是一个非常重要的因素。所有的实验都是在Ubuntu 14.04上运行的机器完成的,其规格如下:

CPU: Intel(R) Core(TM) i7-2600K CPU @ 3.40GHz 
Memory: 2 x 4GB DDR3 1333 MHz