矩阵相乘strassen-c++代码实现及运行实例结果

理论知识参见算法导论第三版以及百度


伪代码

矩阵相乘strassen-c++代码实现及运行实例结果_第1张图片


c++代码

#include 

using namespace std;

#define N 2//以二维方阵为例

template//使用模板保证矩阵可以为不同类型如int,double等
void output(T D[N][N],int n);

template
void strassen(T A[N][N],T B[N][N],T C[N][N],int n);

int main()
{
    int A[N][N]={1,2,3,4};
    cout<<"矩阵A的值"<
void output(T D[N][N],int n)
{
    for(int i=0;i
void matrixAdd(T a[N][N],T b[N][N],T c[N][N],int n)//定义矩阵加法
{
    for(int i=0;i
void matrixSub(T a[N][N],T b[N][N],T c[N][N],int n)//定义矩阵减法
{
    for(int i=0;i
void matrixMul(T a[N][N],T b[N][N],T c[N][N])//定义矩阵乘法
{
    for(int i=0;i<2;++i)//小于2是因为strassen采用递归,递归结束标志是最终分成二阶矩阵
        for(int j=0;j<2;++j)
        {
            c[i][j]=0;
            for(int k=0;k<2;++k)
                c[i][j]+=a[i][k]*b[k][j];
        }
}

template
void strassen(T A[N][N],T B[N][N],T C[N][N],int n)
{
    T A11[N][N],A12[N][N],A21[N][N],A22[N][N];//将矩阵分块
    T B11[N][N],B12[N][N],B21[N][N],B22[N][N];
    T C11[N][N],C12[N][N],C21[N][N],C22[N][N];
    T S1[N][N],S2[N][N],S3[N][N],S4[N][N],S5[N][N],S6[N][N],S7[N][N];//由strassen定义的7个系数
    T temp1[N][N],temp2[N][N];//存储中间量
    if(n==2)//递归结束标志
        matrixMul(A,B,C);
    else
    {
        for(int i=0;i

运行结果

矩阵相乘strassen-c++代码实现及运行实例结果_第2张图片



你可能感兴趣的:(算法导论)