算法导论借助暴力求两NxN矩阵乘积的问题,引出了Strassen算法。下面的代码实现分别对应了书中暴力求解法、分治求解法和Strassen求解法的实现,具体如下文所示。
关于这部分内容的伪代码及说明,可以参看《算法导论》第4.2节。
根据矩阵的乘法知识,两个NxN的矩阵A和B相乘的结果矩阵C的暴力算法是:
/**
 * Multiplies two N x N matrices with the textbook brute-force triple loop.
 *
 * @param A the left operand, an N x N matrix
 * @param B the right operand, an N x N matrix
 * @return a newly allocated N x N matrix holding A * B
 */
public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
    final int n = A.length;
    int[][] product = new int[n][n];
    for (int row = 0; row < n; row++) {
        for (int col = 0; col < n; col++) {
            // Dot product of row `row` of A with column `col` of B.
            int sum = 0;
            for (int k = 0; k < n; k++) {
                sum += A[row][k] * B[k][col];
            }
            product[row][col] = sum;
        }
    }
    return product;
}
加入分治思想后,计算两个NxN矩阵A和B的乘积C的算法是:
/**
 * Multiplies two N x N matrices by divide and conquer (CLRS section 4.2): each operand is
 * split into four N/2 x N/2 quadrants, and every result quadrant is the sum of two
 * recursively computed quadrant products, e.g. C11 = A11 * B11 + A12 * B21.
 *
 * <p>A 0 x 0 input yields a 0 x 0 result. When N is odd, both operands are padded with one
 * zero row and one zero column; the top-left N x N block of the padded product equals A * B.
 *
 * @param A the left operand, an N x N matrix
 * @param B the right operand, an N x N matrix
 * @return a newly allocated N x N matrix holding A * B
 */
public static int[][] martixMultiplyRecursive(int[][] A, int[][] B) {
    int rows = A.length;
    int[][] C = new int[rows][rows];
    if (rows == 0) {
        return C;
    }
    if (rows == 1) {
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }
    if (rows % 2 != 0) {
        // rows / 2 would silently drop the last row and column, so pad to an even size.
        int[][] paddedA = new int[rows + 1][rows + 1];
        int[][] paddedB = new int[rows + 1][rows + 1];
        copySubMatrixByParamFromSrcToDest(A, 0, rows, 0, rows, paddedA);
        copySubMatrixByParamFromSrcToDest(B, 0, rows, 0, rows, paddedB);
        int[][] paddedC = martixMultiplyRecursive(paddedA, paddedB);
        copyMatrixbyParamFromSrcToSubMatrix(paddedC, 0, rows, 0, rows, C);
        return C;
    }
    int half = rows / 2;
    int[][] A11 = new int[half][half];
    int[][] A12 = new int[half][half];
    int[][] A21 = new int[half][half];
    int[][] A22 = new int[half][half];
    copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, 0, half, A11);
    copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, half, half, A12);
    copyMatrixbyParamFromSrcToSubMatrix(A, half, half, 0, half, A21);
    copyMatrixbyParamFromSrcToSubMatrix(A, half, half, half, half, A22);
    int[][] B11 = new int[half][half];
    int[][] B12 = new int[half][half];
    int[][] B21 = new int[half][half];
    int[][] B22 = new int[half][half];
    copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, 0, half, B11);
    copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, half, half, B12);
    copyMatrixbyParamFromSrcToSubMatrix(B, half, half, 0, half, B21);
    copyMatrixbyParamFromSrcToSubMatrix(B, half, half, half, half, B22);
    int[][] C11 = new int[half][half];
    int[][] C12 = new int[half][half];
    int[][] C21 = new int[half][half];
    int[][] C22 = new int[half][half];
    // The eight quadrant products recurse into this method itself (not the brute-force
    // loop), giving the 8-way recursion whose cost the book compares against Strassen's 7.
    squareMatrixElementAdd(martixMultiplyRecursive(A11, B11), martixMultiplyRecursive(A12, B21), C11);
    squareMatrixElementAdd(martixMultiplyRecursive(A11, B12), martixMultiplyRecursive(A12, B22), C12);
    squareMatrixElementAdd(martixMultiplyRecursive(A21, B11), martixMultiplyRecursive(A22, B21), C21);
    squareMatrixElementAdd(martixMultiplyRecursive(A21, B12), martixMultiplyRecursive(A22, B22), C22);
    // Merge the four result quadrants into C.
    copySubMatrixByParamFromSrcToDest(C11, 0, half, 0, half, C);
    copySubMatrixByParamFromSrcToDest(C12, 0, half, half, half, C);
    copySubMatrixByParamFromSrcToDest(C21, half, half, 0, half, C);
    copySubMatrixByParamFromSrcToDest(C22, half, half, half, half, C);
    return C;
}
/**
 * Copies the lenI x lenJ window of src whose top-left corner is (startI, startJ) into
 * dest, starting at dest[0][0]. Used to extract one quadrant of a larger matrix.
 *
 * @param src    matrix to read from
 * @param startI first source row of the window
 * @param lenI   number of rows to copy
 * @param startJ first source column of the window
 * @param lenJ   number of columns to copy
 * @param dest   matrix receiving the window, at least lenI x lenJ
 */
public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
        int[][] dest) {
    for (int r = 0; r < lenI; ++r) {
        final int sourceRow = startI + r;
        for (int c = 0; c < lenJ; ++c) {
            dest[r][c] = src[sourceRow][startJ + c];
        }
    }
}
/**
 * Writes the lenI x lenJ block at the top-left of src into dest, placing src[0][0] at
 * dest[startI][startJ]. Used to put a result quadrant back into the full matrix.
 *
 * @param src    matrix whose top-left lenI x lenJ block is copied
 * @param startI destination row for src[0][0]
 * @param lenI   number of rows to copy
 * @param startJ destination column for src[0][0]
 * @param lenJ   number of columns to copy
 * @param dest   matrix being written into
 */
public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
        int[][] dest) {
    for (int r = 0; r < lenI; ++r) {
        final int targetRow = startI + r;
        for (int c = 0; c < lenJ; ++c) {
            dest[targetRow][startJ + c] = src[r][c];
        }
    }
}
/**
 * Element-wise matrix addition: dest = srcA + srcB. The loop bounds follow srcA's shape.
 * dest may be the same array as srcA or srcB.
 *
 * @param srcA first addend
 * @param srcB second addend
 * @param dest matrix receiving the sum
 */
public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
    for (int r = 0; r < srcA.length; r++) {
        int[] rowA = srcA[r];
        for (int c = 0; c < rowA.length; c++) {
            dest[r][c] = rowA[c] + srcB[r][c];
        }
    }
}
/**
 * Prints a matrix to standard output, one row per line, with every value followed by a
 * single space.
 *
 * @param matrix the matrix to print
 */
public static void displaySquare(int[][] matrix) {
    for (int[] row : matrix) {
        StringBuilder line = new StringBuilder();
        for (int value : row) {
            line.append(value).append(' ');
        }
        System.out.println(line);
    }
}
使用Strassen算法求两个方阵乘积的实现代码是:
/**
 * Multiplies two N x N matrices with Strassen's algorithm (CLRS section 4.2).
 *
 * <p>The algorithm forms ten quadrant sums/differences S1..S10 and computes only seven
 * recursive products P1..P7. It then assembles the result quadrants from those products
 * using additions and subtractions.
 *
 * <p>A 0 x 0 input yields a 0 x 0 result. When N is odd, both operands are padded with one
 * zero row and one zero column; the top-left N x N block of the padded product equals A * B.
 *
 * @param A the left operand, an N x N matrix
 * @param B the right operand, an N x N matrix
 * @return a newly allocated N x N matrix holding A * B
 */
public static int[][] strassenMartixMultiplyRecursive(int[][] A, int[][] B) {
    int rows = A.length;
    int[][] C = new int[rows][rows];
    if (rows == 0) {
        // Must return here: splitting a 0 x 0 matrix yields 0 x 0 quadrants and would recurse forever.
        return C;
    }
    if (rows == 1) {
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }
    if (rows % 2 != 0) {
        // rows / 2 would silently drop the last row and column, so pad to an even size.
        int[][] paddedA = new int[rows + 1][rows + 1];
        int[][] paddedB = new int[rows + 1][rows + 1];
        copySubMatrixByParamFromSrcToDest(A, 0, rows, 0, rows, paddedA);
        copySubMatrixByParamFromSrcToDest(B, 0, rows, 0, rows, paddedB);
        int[][] paddedC = strassenMartixMultiplyRecursive(paddedA, paddedB);
        copyMatrixbyParamFromSrcToSubMatrix(paddedC, 0, rows, 0, rows, C);
        return C;
    }
    int half = rows / 2;
    // Step 1: split A and B into quadrants.
    int[][] A11 = new int[half][half];
    int[][] A12 = new int[half][half];
    int[][] A21 = new int[half][half];
    int[][] A22 = new int[half][half];
    copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, 0, half, A11);
    copyMatrixbyParamFromSrcToSubMatrix(A, 0, half, half, half, A12);
    copyMatrixbyParamFromSrcToSubMatrix(A, half, half, 0, half, A21);
    copyMatrixbyParamFromSrcToSubMatrix(A, half, half, half, half, A22);
    int[][] B11 = new int[half][half];
    int[][] B12 = new int[half][half];
    int[][] B21 = new int[half][half];
    int[][] B22 = new int[half][half];
    copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, 0, half, B11);
    copyMatrixbyParamFromSrcToSubMatrix(B, 0, half, half, half, B12);
    copyMatrixbyParamFromSrcToSubMatrix(B, half, half, 0, half, B21);
    copyMatrixbyParamFromSrcToSubMatrix(B, half, half, half, half, B22);
    // Step 2: the ten sums/differences S1..S10.
    int[][] S1 = new int[half][half];
    int[][] S2 = new int[half][half];
    int[][] S3 = new int[half][half];
    int[][] S4 = new int[half][half];
    int[][] S5 = new int[half][half];
    int[][] S6 = new int[half][half];
    int[][] S7 = new int[half][half];
    int[][] S8 = new int[half][half];
    int[][] S9 = new int[half][half];
    int[][] S10 = new int[half][half];
    squareMatrixElementSub(B12, B22, S1); // S1 = B12 - B22
    squareMatrixElementAdd(A11, A12, S2); // S2 = A11 + A12
    squareMatrixElementAdd(A21, A22, S3); // S3 = A21 + A22
    squareMatrixElementSub(B21, B11, S4); // S4 = B21 - B11
    squareMatrixElementAdd(A11, A22, S5); // S5 = A11 + A22
    squareMatrixElementAdd(B11, B22, S6); // S6 = B11 + B22
    squareMatrixElementSub(A12, A22, S7); // S7 = A12 - A22
    squareMatrixElementAdd(B21, B22, S8); // S8 = B21 + B22
    squareMatrixElementSub(A11, A21, S9); // S9 = A11 - A21
    squareMatrixElementAdd(B11, B12, S10); // S10 = B11 + B12
    // Step 3: only seven recursive products (the plain recursion needs eight).
    int[][] P1 = strassenMartixMultiplyRecursive(A11, S1); // P1 = A11 x S1
    int[][] P2 = strassenMartixMultiplyRecursive(S2, B22); // P2 = S2 x B22
    int[][] P3 = strassenMartixMultiplyRecursive(S3, B11); // P3 = S3 x B11
    int[][] P4 = strassenMartixMultiplyRecursive(A22, S4); // P4 = A22 x S4
    int[][] P5 = strassenMartixMultiplyRecursive(S5, S6); // P5 = S5 x S6
    int[][] P6 = strassenMartixMultiplyRecursive(S7, S8); // P6 = S7 x S8
    int[][] P7 = strassenMartixMultiplyRecursive(S9, S10); // P7 = S9 x S10
    // Step 4: combine the products into the result quadrants.
    int[][] C11 = new int[half][half];
    int[][] C12 = new int[half][half];
    int[][] C21 = new int[half][half];
    int[][] C22 = new int[half][half];
    int[][] temp = new int[half][half];
    // C11 = P5 + P4 - P2 + P6
    squareMatrixElementAdd(P5, P4, temp);
    squareMatrixElementSub(temp, P2, temp);
    squareMatrixElementAdd(temp, P6, C11);
    // C12 = P1 + P2
    squareMatrixElementAdd(P1, P2, C12);
    // C21 = P3 + P4
    squareMatrixElementAdd(P3, P4, C21);
    // C22 = P5 + P1 - P3 - P7
    squareMatrixElementAdd(P5, P1, temp);
    squareMatrixElementSub(temp, P3, temp);
    squareMatrixElementSub(temp, P7, C22);
    // Merge the four result quadrants into C.
    copySubMatrixByParamFromSrcToDest(C11, 0, half, 0, half, C);
    copySubMatrixByParamFromSrcToDest(C12, 0, half, half, half, C);
    copySubMatrixByParamFromSrcToDest(C21, half, half, 0, half, C);
    copySubMatrixByParamFromSrcToDest(C22, half, half, half, half, C);
    return C;
}
/**
 * Copies the lenI x lenJ window of src whose top-left corner is (startI, startJ) into
 * dest, starting at dest[0][0]. Used to extract one quadrant of a larger matrix.
 *
 * @param src    matrix to read from
 * @param startI first source row of the window
 * @param lenI   number of rows to copy
 * @param startJ first source column of the window
 * @param lenJ   number of columns to copy
 * @param dest   matrix receiving the window, at least lenI x lenJ
 */
public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
        int[][] dest) {
    for (int r = 0; r < lenI; ++r) {
        final int sourceRow = startI + r;
        for (int c = 0; c < lenJ; ++c) {
            dest[r][c] = src[sourceRow][startJ + c];
        }
    }
}
/**
 * Writes the lenI x lenJ block at the top-left of src into dest, placing src[0][0] at
 * dest[startI][startJ]. Used to put a result quadrant back into the full matrix.
 *
 * @param src    matrix whose top-left lenI x lenJ block is copied
 * @param startI destination row for src[0][0]
 * @param lenI   number of rows to copy
 * @param startJ destination column for src[0][0]
 * @param lenJ   number of columns to copy
 * @param dest   matrix being written into
 */
public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
        int[][] dest) {
    for (int r = 0; r < lenI; ++r) {
        final int targetRow = startI + r;
        for (int c = 0; c < lenJ; ++c) {
            dest[targetRow][startJ + c] = src[r][c];
        }
    }
}
/**
 * Element-wise matrix addition: dest = srcA + srcB. The loop bounds follow srcA's shape.
 * dest may be the same array as srcA or srcB.
 *
 * @param srcA first addend
 * @param srcB second addend
 * @param dest matrix receiving the sum
 */
public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
    for (int r = 0; r < srcA.length; r++) {
        int[] rowA = srcA[r];
        for (int c = 0; c < rowA.length; c++) {
            dest[r][c] = rowA[c] + srcB[r][c];
        }
    }
}
/**
 * Element-wise matrix subtraction: dest = srcA - srcB. The loop bounds follow srcA's shape.
 * dest may be the same array as srcA or srcB.
 *
 * @param srcA minuend
 * @param srcB subtrahend
 * @param dest matrix receiving the difference
 */
public static void squareMatrixElementSub(int[][] srcA, int[][] srcB, int[][] dest) {
    for (int r = 0; r < srcA.length; r++) {
        int[] rowA = srcA[r];
        for (int c = 0; c < rowA.length; c++) {
            dest[r][c] = rowA[c] - srcB[r][c];
        }
    }
}
最后,本文涉及的完整测试代码如下:
public class StrassenAlgor {
static int[][] A = {
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 }
};
static int[][] B = {
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 },
{ 1, 2, 3, 1 },
{ 1, 2, 2, 1 }
};
public static void main(String[] args) {
System.out.println("使用暴力迭代形式的方阵矩阵求积");
int[][] C = martixMultiplyRecursive(A, B);
displaySquare(C);
System.out.println("使用分治思想的普通形式的方阵矩阵求积");
int[][] C1 = martixMultiplyRecursive(A, B);
displaySquare(C1);
System.out.println("Strassen 方阵求积");
int[][] C2 = strassenMartixMultiplyRecursive(A, B);
displaySquare(C2);
}
/**
* 一般的暴力矩阵乘法运算;矩阵A和B都是NxN的方阵
*
* @param A
* 参加运算的矩阵之一A
* @param B
* 参加运算的矩阵之一B
* @return 矩阵A和B相乘得到的矩阵C
*/
public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < rows; j++) {
C[i][j] = 0;
for (int k = 0; k < rows; k++) {
C[i][j] = C[i][j] + A[i][k] * B[k][j];
}
}
}
return C;
}
/**
* 使用分治算法的NxN矩阵乘法运算
*
* @param A
* 参加运算的矩阵之一A
* @param B
* 参加运算的矩阵之一B
* @return
*/
public static int[][] martixMultiplyRecursive(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int[][] A11 = new int[rows / 2][rows / 2];
int[][] A12 = new int[rows / 2][rows / 2];
int[][] A21 = new int[rows / 2][rows / 2];
int[][] A22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);
int[][] B11 = new int[rows / 2][rows / 2];
int[][] B12 = new int[rows / 2][rows / 2];
int[][] B21 = new int[rows / 2][rows / 2];
int[][] B22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);
int[][] C11 = new int[rows / 2][rows / 2];
int[][] C12 = new int[rows / 2][rows / 2];
int[][] C21 = new int[rows / 2][rows / 2];
int[][] C22 = new int[rows / 2][rows / 2];
squareMatrixElementAdd(squareMatrixMultiply(A11, B11), squareMatrixMultiply(A12, B21), C11);
squareMatrixElementAdd(squareMatrixMultiply(A11, B12), squareMatrixMultiply(A12, B22), C12);
squareMatrixElementAdd(squareMatrixMultiply(A21, B11), squareMatrixMultiply(A22, B21), C21);
squareMatrixElementAdd(squareMatrixMultiply(A21, B12), squareMatrixMultiply(A22, B22), C22);
// 将C11/C12/C21/C22四个子矩阵合并为最终的结果C矩阵
copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);
}
return C;
}
/**
* Strassen算法的NxN矩阵乘法运算
*
* @param A
* 参加运算的矩阵之一A
* @param B
* 参加运算的矩阵之一B
* @return
*/
public static int[][] strassenMartixMultiplyRecursive(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int[][] A11 = new int[rows / 2][rows / 2];
int[][] A12 = new int[rows / 2][rows / 2];
int[][] A21 = new int[rows / 2][rows / 2];
int[][] A22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);
int[][] B11 = new int[rows / 2][rows / 2];
int[][] B12 = new int[rows / 2][rows / 2];
int[][] B21 = new int[rows / 2][rows / 2];
int[][] B22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);
int[][] S1 = new int[rows / 2][rows / 2];
int[][] S2 = new int[rows / 2][rows / 2];
int[][] S3 = new int[rows / 2][rows / 2];
int[][] S4 = new int[rows / 2][rows / 2];
int[][] S5 = new int[rows / 2][rows / 2];
int[][] S6 = new int[rows / 2][rows / 2];
int[][] S7 = new int[rows / 2][rows / 2];
int[][] S8 = new int[rows / 2][rows / 2];
int[][] S9 = new int[rows / 2][rows / 2];
int[][] S10 = new int[rows / 2][rows / 2];
squareMatrixElementSub(B12, B22, S1);// S1 = B12 - B22
squareMatrixElementAdd(A11, A12, S2);// S2 = A11 + A12
squareMatrixElementAdd(A21, A22, S3);// S3 = A21 + A22
squareMatrixElementSub(B21, B11, S4);// S4 = B21 - B11
squareMatrixElementAdd(A11, A22, S5);// S5 = A11 + A22
squareMatrixElementAdd(B11, B22, S6);// S6 = B11 + B22
squareMatrixElementSub(A12, A22, S7);// S7 = A12 - A22
squareMatrixElementAdd(B21, B22, S8);// S8 = B21 + B22
squareMatrixElementSub(A11, A21, S9);// S9 = A11 - A21
squareMatrixElementAdd(B11, B12, S10);// S10 = B11 + B12
int[][] P1 = new int[rows / 2][rows / 2];
int[][] P2 = new int[rows / 2][rows / 2];
int[][] P3 = new int[rows / 2][rows / 2];
int[][] P4 = new int[rows / 2][rows / 2];
int[][] P5 = new int[rows / 2][rows / 2];
int[][] P6 = new int[rows / 2][rows / 2];
int[][] P7 = new int[rows / 2][rows / 2];
P1 = strassenMartixMultiplyRecursive(A11, S1); // P1 = A11 X S1
P2 = strassenMartixMultiplyRecursive(S2, B22);// P2 = S2 X B22
P3 = strassenMartixMultiplyRecursive(S3, B11);// P3 = S3 X B11
P4 = strassenMartixMultiplyRecursive(A22, S4);// P4 = A22 X S4
P5 = strassenMartixMultiplyRecursive(S5, S6);// P5 = S5 X S6
P6 = strassenMartixMultiplyRecursive(S7, S8);// P6 = S7 X S8
P7 = strassenMartixMultiplyRecursive(S9, S10);// P7 = S9 X S10
int[][] C11 = new int[rows / 2][rows / 2];
int[][] C12 = new int[rows / 2][rows / 2];
int[][] C21 = new int[rows / 2][rows / 2];
int[][] C22 = new int[rows / 2][rows / 2];
int[][] temp = new int[rows / 2][rows / 2];
// C11 = P5 + P4 - P2 + P6
squareMatrixElementAdd(P5, P4, temp);
squareMatrixElementSub(temp, P2, temp);
squareMatrixElementAdd(temp, P6, C11);
// C12 = P1 + P2
squareMatrixElementAdd(P1, P2, C12);
// C21 = P3 + P4
squareMatrixElementAdd(P3, P4, C21);
// C22 = P5 + P1 - P3 -P7
squareMatrixElementAdd(P5, P1, temp);
squareMatrixElementSub(temp, P3, temp);
squareMatrixElementSub(temp, P7, C22);
// 将C11/C12/C21/C22四个子矩阵合并为最终的结果C矩阵
copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);
}
return C;
}
/**
* 将一个NxN的大矩阵分解成4个N/2xN/2的子矩阵
*
*/
public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[i][j] = src[startI + i][startJ + j];
}
}
/**
* 将4个N/2xN/2的子矩阵合并成一个NxN的大矩阵
*
*/
public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[startI + i][startJ + j] = src[i][j];
}
}
/**
* NxN矩阵加法
*
* @param srcA
* 加法源矩阵之一
* @param srcB
* 加法源矩阵之二
* @param dest
* 矩阵加法结果
*/
public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] + srcB[i][j];
}
/**
* NxN矩阵减法
*
* @param srcA
* 减法源矩阵之一
* @param srcB
* 减法源矩阵之二
* @param dest
* 矩阵减法结果
*/
public static void squareMatrixElementSub(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] - srcB[i][j];
}
/**
* 打印NxN矩阵
*
*/
public static void displaySquare(int[][] matrix) {
for (int i = 0; i < matrix.length; i++) {
for (int j : matrix[i]) {
System.out.print(j + " ");
}
System.out.println();
}
}
}
输出如下:
使用暴力迭代形式的方阵矩阵求积
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
使用分治思想的普通形式的方阵矩阵求积
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
Strassen 方阵求积
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
PS:
我一直在想,为什么Strassen算法能降低方阵求积的复杂度;后来在知乎上看到一个答案,感觉解释得挺好,现贴出来一起分享。先补充一点定量的说明:普通分治算法的运行时间满足 $T(n)=8T(n/2)+\Theta(n^2)$,解得 $T(n)=\Theta(n^3)$;Strassen算法每层只需7次递归乘法,满足 $T(n)=7T(n/2)+\Theta(n^2)$,解得 $T(n)=\Theta(n^{\lg 7})\approx\Theta(n^{2.81})$。
"strassen算法的关键不在于是乘法还是加法,而是在于算法内部递归调用的次数。strassen算法的关键在于内部递归调用的次数减少了1(从普通的8次变为特殊的7次)。这里的一个结论就是递归算法中递归调用次数少,时间复杂度低。这很容易理解,在算法导论中用了“茂盛”度来描述这一时间复杂度在递归算法中的变化。所以strassen算法的关键在于,递归调用的次数是怎么从8次减少到7次的。反推理解一下,这说明8次递归调用中有一次是冗余的,即第8次递归乘法的结果信息已经包含在了前7次的结果里,前7次的计算结果通过线性组合就能得到第8个递归的结果了。而该线性组合的时间复杂度低于该算法本身(即一次递归调用)的时间复杂度。做一个结论:但凡是能够优化时间复杂度的算法,高复杂度的算法中必然是有一些计算是冗余的,如能用更少的计算代替冗余,就能提高效率。(因为算法递归的刚好是乘法,所以此处看起来似乎是重点放在了乘法上)
至于为什么传统矩阵相乘算法中有冗余计算,也尝试分析一下:
冗余的根本原因应该在于基本的乘法分配律a*(b+c)=a*b+a*c。同样的计算结果,前一种(等号前)方法计算需要2次基本运算,而后一种(等号后)方法需要3次。(假设乘法运算和加法运算是同等开销的基本运算)。而一般的矩阵乘法算法中是大量的单步乘法运算后求和,即采用的是上述等号右边的计算式。如果能有一种方法,将乘法运算中的相同因子提到前边来,运用上述乘法分配律转换计算形式,那么就能提高计算效率。这应该就是strassen算法的本质。看strassen算法的过程,就是先将一部分子矩阵进行加(减)运算,再进行乘法运算。其实就是构造了上述分配律的左式。"