strassen矩阵乘法

Strassen矩阵乘法简要解析

Strassen矩阵乘法具体描述如下:

两个n×n 阶的矩阵A与B的乘积是另一个n×n 阶矩阵C,C可表示为假如每一个C(i, j) 都用此公式计算,则计算C所需要的操作次数为n3 m+n2 (n- 1) a,其中m表示一次乘法,a 表示一次加法或减法。

为了使讨论简便,假设n 是2的幂(也就是说, n是1,2,4,8,1 6,...)。

首先,假设n= 1时是一个小问题,n> 1时为一个大问题。后面将根据需要随时修改这个假设。对于1×1阶的小矩阵,可以通过将两矩阵中的两个元素直接相乘而得到结果。

考察一个n> 1的大问题。可以将这样的矩阵分成4个n/ 2×n/ 2阶的矩阵A1,A2,A3,和A4。当n 大于1且n 是2的幂时,n/ 2也是2的幂。因此较小矩阵也满足前面对矩阵大小的假设。矩阵Bi 和Ci 的定义与此类似.

假定strassen矩阵分割方案仅用于n≥8的矩阵乘法,而对于n<8的矩阵乘法则直接利用公式进行计算。则n= 8时,8×8矩阵相乘需要7次4×4矩阵乘法和1 8次4×4矩阵加/减法。每次矩阵乘法需花费6 4m+ 4 8a次操作,每次矩阵加法或减法需花费1 6a次操作。因此总的操作次数为7 ( 6 4m+ 4 8a) + 1 8 ( 1 6a) = 4 4 8m+ 6 2 4a。而使用直接计算方法,则需要5 1 2m+ 4 4 8a次操作。要使S t r a s s e n方法比直接计算方法快,至少要求5 1 2-4 4 8次乘法的开销比6 2 4-4 4 8次加/减法的开销大。或者说一次乘法的开销应该大于近似2 . 7 5次加/减法的开销。

假定n<1 6的矩阵是一个“小”问题,strassen的分解方案仅仅用于n≥1 6的情况,对于n<1 6的矩阵相乘,直接利用公式。则当n= 1 6时使用分而治之算法需要7 ( 5 1 2m+ 4 4 8a) +1 8 ( 6 4a) = 3 5 8 4m+ 4 2 8 8a次操作。直接计算时需要4 0 9 6m+ 3 8 4 0a次操作。若一次乘法的开销与一次加/减法的开销相同,则strassen方法需要7872次操作及用于问题分解的额外时间,而直接计算方法则需要7936次操作加上程序中执行for循环以及其他语句所花费的时间。即使直接计算方法所需要的操作次数比strassen方法少,但由于直接计算方法需要更多的额外开销,因此它也不见得会比strassen方法快。

n 的值越大,Strassen 方法与直接计算方法所用的操作次数的差异就越大,因此对于足够大的n,Strassen 方法将更快。设t (n) 表示使用Strassen 分而治之方法所需的时间。因为大的矩阵会被递归地分割成小矩阵直到每个矩阵的大小小于或等于k(k至少为8,也许更大,具体值由计算机的性能决定). 用迭代方法计算,可得t(n) = (nlog27)。因为log27 ≈2 . 8 1,所以与直接计算方法的复杂性(n3)相比,分而治之矩阵乘法算法有较大的改进。

再次说明:

矩阵C = AB,可写为
C11 = A11B11 + A12B21
C12 = A11B12 + A12B22
C21 = A21B11 + A22B21
C22 = A21B12 + A22B22
如果A、B、C都是二阶矩阵,则共需要8次乘法和4次加法。如果阶大于2,可以将矩阵分块进行计算。耗费的时间是O(nE3)。

要改进算法计算时间的复杂度,必须减少乘法运算次数。按分治法的思想,Strassen提出一种新的方法,用7次乘法完成2阶矩阵的乘法,算法如下:
M1 = A11(B12 - B12)
M2 = (A11 + A12)B22
M3 = (A21 + A22)B11
M4 = A22(B21 - B11)
M5 = (A11 + A22)(B11 + B22)
M6 = (A12 - A22)(B21 + B22)
M7 = (A11 - A21)(B11 + B12)
完成了7次乘法,再做如下加法:
C11 = M5 + M4 - M2 + M6
C12 = M1 + M2
C21 = M3 + M4
C22 = M5 + M1 - M3 - M7
全部计算使用了7次乘法和18次加减法,计算时间降低到O(nE2.81)。计算复杂性得到较大改进。


STRASSEN矩阵乘法算法如下:

#include <iostream.h>

const int N=8; //常量N用来定义矩阵的大小

void main()
{

   void STRASSEN(int n,float A[][N],float B[][N],float C[][N]);
   void input(int n,float p[][N]);
   void output(int n,float C[][N]);                    //函数声明部分

   float A[N][N],B[N][N],C[N][N];  //定义三个矩阵A,B,C

   cout<<"现在录入矩阵A[N][N]:"<<endl<<endl;
   input(N,A);
   cout<<endl<<"现在录入矩阵B[N][N]:"<<endl<<endl;
   input(N,B);                         //录入数组

   STRASSEN(N,A,B,C);   //调用STRASSEN函数计算

   output(N,C);  //输出计算结果
}


void input(int n,float p[][N])  //矩阵输入函数
{
   int i,j;

   for(i=0;i<n;i++)
   {
       cout<<"请输入第"<<i+1<<"行"<<endl;
       for(j=0;j<n;j++)
           cin>>p[i][j];
   }
}

void output(int n,float C[][N]) //据矩阵输出函数
{
   int i,j;
   cout<<"输出矩阵:"<<endl;
   for(i=0;i<n;i++)
   {
       cout<<endl;
       for(j=0;j<n;j++)
           cout<<C[i][j]<<"  ";
   }
   cout<<endl<<endl;

}

void MATRIX_MULTIPLY(float A[][N],float B[][N],float C[][N])  //按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
{
   int i,j,t;
   for(i=0;i<2;i++)                     //计算A*B-->C
       for(j=0;j<2;j++)
       {    
           C[i][j]=0;                   //计算完一个C[i][j],C[i][j]应重新赋值为零
           for(t=0;t<2;t++)
           C[i][j]=C[i][j]+A[i][t]*B[t][j];
       }
}

void MATRIX_ADD(int n,float X[][N],float Y[][N],float Z[][N]) //矩阵加法函数X+Y―>Z
{
   int i,j;
   for(i=0;i<n;i++)
       for(j=0;j<n;j++)
           Z[i][j]=X[i][j]+Y[i][j];
}

void MATRIX_SUB(int n,float X[][N],float Y[][N],float Z[][N]) //矩阵减法函数X-Y―>Z
{
   int i,j;
   for(i=0;i<n;i++)
       for(j=0;j<n;j++)
           Z[i][j]=X[i][j]-Y[i][j];

}


void STRASSEN(int n,float A[][N],float B[][N],float C[][N])  //STRASSEN函数(递归)
{
   float A11[N][N],A12[N][N],A21[N][N],A22[N][N];
   float B11[N][N],B12[N][N],B21[N][N],B22[N][N];
   float C11[N][N],C12[N][N],C21[N][N],C22[N][N];
   float M1[N][N],M2[N][N],M3[N][N],M4[N][N],M5[N][N],M6[N][N],M7[N][N];
   float AA[N][N],BB[N][N],MM1[N][N],MM2[N][N];

   int i,j;//,x;


   if (n==2)
       MATRIX_MULTIPLY(A,B,C);//按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
   else
   {
       for(i=0;i<n/2;i++)                        
           for(j=0;j<n/2;j++)

               {
                   A11[i][j]=A[i][j];
                   A12[i][j]=A[i][j+n/2];
                   A21[i][j]=A[i+n/2][j];
                   A22[i][j]=A[i+n/2][j+n/2];
                   B11[i][j]=B[i][j];
                   B12[i][j]=B[i][j+n/2];
                   B21[i][j]=B[i+n/2][j];
                   B22[i][j]=B[i+n/2][j+n/2];
               }                                   //将矩阵A和B式分为四块




   MATRIX_SUB(n/2,B12,B22,BB);                      

   STRASSEN(n/2,A11,BB,M1);//M1=A11(B12-B22)

   MATRIX_ADD(n/2,A11,A12,AA);
   STRASSEN(n/2,AA,B22,M2);//M2=(A11+A12)B22

   MATRIX_ADD(n/2,A21,A22,AA);
   STRASSEN(n/2,AA,B11,M3);//M3=(A21+A22)B11

   MATRIX_SUB(n/2,B21,B11,BB);
   STRASSEN(n/2,A22,BB,M4);//M4=A22(B21-B11)

   MATRIX_ADD(n/2,A11,A22,AA);
   MATRIX_ADD(n/2,B11,B22,BB);
   STRASSEN(n/2,AA,BB,M5);//M5=(A11+A22)(B11+B22)

   MATRIX_SUB(n/2,A12,A22,AA);
   MATRIX_SUB(n/2,B21,B22,BB);
   STRASSEN(n/2,AA,BB,M6);//M6=(A12-A22)(B21+B22)

   MATRIX_SUB(n/2,A11,A21,AA);
   MATRIX_SUB(n/2,B11,B12,BB);
   STRASSEN(n/2,AA,BB,M7);//M7=(A11-A21)(B11+B12)
                                                   //计算M1,M2,M3,M4,M5,M6,M7(递归部分)


   MATRIX_ADD(N/2,M5,M4,MM1);                        
   MATRIX_SUB(N/2,M2,M6,MM2);
   MATRIX_SUB(N/2,MM1,MM2,C11);//C11=M5+M4-M2+M6

   MATRIX_ADD(N/2,M1,M2,C12);//C12=M1+M2

   MATRIX_ADD(N/2,M3,M4,C21);//C21=M3+M4

   MATRIX_ADD(N/2,M5,M1,MM1);
   MATRIX_ADD(N/2,M3,M7,MM2);
   MATRIX_SUB(N/2,MM1,MM2,C22);//C22=M5+M1-M3-M7

   for(i=0;i<n/2;i++)
       for(j=0;j<n/2;j++)
       {
           C[i][j]=C11[i][j];
           C[i][j+n/2]=C12[i][j];
           C[i+n/2][j]=C21[i][j];
           C[i+n/2][j+n/2]=C22[i][j];
       }                                            //计算结果送回C[N][N]

   }


你可能感兴趣的:(矩阵乘法)