Strassen矩阵乘法(C++)

思路

两个矩阵A,B相乘时.有以下三种方法

暴力计算法. 三个for循环, 这时候时间复杂度为O(n^3).因为Cij=∑(k=1->n)Aik*Bkj,需要一个循环, 且C中有n^2个元素, 所以时间复杂度为O(n^3)

分治法. 首先将A,B,C分成相等大小的方块矩阵.

所以C11=A11*B11+A12*B21, C12=A11*B12+A12*B22,

C21=A21*B11+A22*B21, C22=A21*B12+A22*B22

用T(n)表示n*n矩阵的乘法, 所以有T(n)=8T(n/2)+Θ(n^2). 其中, 8T(n/2)表示8次子矩阵乘法, 子矩阵的规模为n/2 * n/2. θ(n^2)表示4次矩阵加法的时间复杂度以及合并C矩阵的时间复杂度.最后结果是Θ(n^3)与暴力计算时间复杂度相同.

Strassen算法,可以将时间复杂度优化到O(n^log7).

现在重新定义7个新矩阵

M1=(A11+A22)*(B11+B22)

M2=(A21+A22)*B11

M3=A11*(B12-B22)

M4=A22*(B21-B11)

M5=(A11+A12)*B22

M6=(A21-A11)*(B11+B12)

M7=(A12-A22)*(B21+B22)

结果矩阵C可以组合上述矩阵,如下

C11=M1+M4-M5+M7

C12=M3+M5

C21=M2+M4

C22=M1-M2+M3+M6

这时候共用了7次乘法,18次加减法运算. 写出递推公式T(n)=7T(n/2)+Θ(n^2). 最终结果是O(n^log7)=O(n^2.807).

代码如下:

#include 

using namespace std;

// 矩阵相乘的暴力求解
void MUL(int** MatrixA,int** MatrixB,int** MatrixResult,int Msize){
    for(int i=0;i> MSize;

    // 定义三个矩阵
    int** MatrixA;
    int** MatrixB;
    int** MatrixC;

    // 初始化三个矩阵
    MatrixA=new int*[MSize];
    MatrixB=new int*[MSize];
    MatrixC=new int*[MSize];
    for(int i=0;i> MatrixA[i][j];
        }
    }
    for(int i=0;i> MatrixB[i][j];
        }
    }

    Strassen(MSize,MatrixA,MatrixB,MatrixC);

    // 打印输出结果矩阵
    for(int i=0;i

你可能感兴趣的:(C++,矩阵,线性代数,算法)