【算法导论】矩阵乘法strassen算法

      矩阵运算在做科学运算时是必不可少的,如果采用matlab来计算,这倒也容易。但是如果是自己写c或者c++代码,一般而言,需要做三次循环,其时间复杂度就是O(n^3)。

【算法导论】矩阵乘法strassen算法_第1张图片


       上图给出了我们一般会采用的方法,就是对应元素相乘和相加。如果把C=A*B进行分解,可以看出,这里需要进行8次的乘法运算:

【算法导论】矩阵乘法strassen算法_第2张图片

分别是:

r = a * e + b * g ;
s = a * f  + b * h ;
t = c * e + d  * g; 
u = c * f + d * h;

本文介绍的算法就是strassen提出的,可以将8次乘法降为7次乘法,虽然只是一次乘法,但是其实一次算法耗时要比加减法多很多。处理的方法是写成:

p1 = a * ( f - h )

p2 = ( a + b ) *  h

p3 = ( c +d ) * e

p4 = d *  ( g - e )

p5 = ( a + d ) * ( e + h )

p6 =  ( b - d ) * ( g + h ) 

p7 = ( a - c ) * ( e + f )

那么只需要计算p1,p2,p3,p4,p5,p6,p7,然后

r  = p5 + p4 + p6 - p2

s = p1 + p2

t = p3 + p4

u = p5 + p1 - p3 - p7

这样,八次的乘法就变成了7次乘法和一次加减法,最终达到降低复杂度为O( n^lg7 ) ~= O( n^2.81 );

c++代码如下:

/*
Strassen Algorithm Implementation in C++
Coded By: Seyyed Hossein Hasan Pour MatiKolaee in May 5 2010 .
Mazandaran University of Science and Technology,Babol,Mazandaran,Iran
--------------------------------------------
Email : [email protected]
YM    : [email protected]
Updated may 09 2010.
*/
#include 
#include 
#include 
#include 
#include 
using namespace std;

int Strassen(int n, int** MatrixA, int ** MatrixB, int ** MatrixC);//Multiplies Two Matrices recrusively.
int ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int length );//Adds two Matrices, and places the result in another Matrix
int SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int length );//subtracts two Matrices , and places  the result in another Matrix
int MUL(int** MatrixA, int** MatrixB, int** MatrixResult, int length );//Multiplies two matrices in conventional way.
void FillMatrix( int** matrix1, int** matrix2, int length);//Fills Matrices with random numbers.
void PrintMatrix( int **MatrixA, int MatrixSize );//prints the Matrix content.

int main()
{

    int MatrixSize = 0;

    int** MatrixA;
    int** MatrixB;
    int** MatrixC;

    clock_t startTime_For_Normal_Multipilication ;
    clock_t endTime_For_Normal_Multipilication ;

    clock_t startTime_For_Strassen ;
    clock_t endTime_For_Strassen ;

    time_t start,end;

    srand(time(0));

    cout<>MatrixSize;

    int N = MatrixSize;//for readiblity.


    MatrixA = new int *[MatrixSize];
    MatrixB = new int *[MatrixSize];
    MatrixC = new int *[MatrixSize];

    for (int i = 0; i < MatrixSize; i++)
    {
        MatrixA[i] = new int [MatrixSize];
        MatrixB[i] = new int [MatrixSize];
        MatrixC[i] = new int [MatrixSize];
    }

    FillMatrix(MatrixA,MatrixB,MatrixSize);

  //*******************conventional multiplication test
        cout<<"Phase I started:  "<< (startTime_For_Normal_Multipilication = clock());

		MUL(MatrixA,MatrixB,MatrixC,MatrixSize);

        cout<<"\nPhase I ended: "<< (endTime_For_Normal_Multipilication = clock());

		cout<<"\nMatrix Result... \n";
	    PrintMatrix(MatrixC,MatrixSize);

  //*******************Strassen multiplication test
        cout<<"\nMultiplication started: "<< (startTime_For_Strassen = clock());

		Strassen( N, MatrixA, MatrixB, MatrixC );

		cout<<"\nMultiplication: "<<(endTime_For_Strassen = clock());


	cout<<"\nMatrix Result... \n";
	PrintMatrix(MatrixC,MatrixSize);

	cout<<"Matrix size "<


你可能感兴趣的:(Data,Structure,and,Algorithm)