矩阵乘法的三种算法(蛮力嵌套循环法,分治递归法,Strassen法)

目录

一.矩阵乘法的嵌套循环算法

二.矩阵乘法的递归算法

三.矩阵乘法的Strassen算法


一.矩阵乘法的嵌套循环算法

伪代码:

[图 1:矩阵乘法嵌套循环算法的伪代码]

C++代码:

//1.矩阵乘法的嵌套循环算法
#include
using namespace std;

void Square_MA_MU(int a[][3],int b[][3],int c[][3],int n) //传递二维数组参数时必须要确定列数
{
	for (int i = 0; i < n; i++)  //行遍历
	{
		for(int j=0;j

二.矩阵乘法的递归算法

伪代码:

[图 2:矩阵乘法分治递归算法的伪代码]

C++代码:

#include
using namespace std;

// Divide-and-conquer helper: accumulates the product of a rows x inner block
// of a (top-left corner at (ar, ac)) and an inner x cols block of b (top-left
// corner at (br, bc)) into the rows x cols block of c whose corner is (ar, bc):
//   c[ar + i][bc + j] += sum_k a[ar + i][ac + k] * b[br + k][bc + j]
// Every dimension d is split into d / 2 and d - d / 2, so odd sizes are
// handled; blocks with an empty dimension are no-ops, which also ends the
// recursion for size 0.
static void matrix_multi_block(int a[][8], int ar, int ac, int b[][8], int br, int bc,
                               int rows, int inner, int cols, int c[][8])
{
	if (rows <= 0 || inner <= 0 || cols <= 0)
	{
		return; // empty block: nothing to accumulate
	}
	if (rows == 1 && inner == 1 && cols == 1)
	{
		c[ar][bc] += a[ar][ac] * b[br][bc];
		return;
	}

	// Split sizes: the second part gets the extra element when d is odd.
	const int r1 = rows / 2;
	const int r2 = rows - r1;
	const int k1 = inner / 2;
	const int k2 = inner - k1;
	const int c1 = cols / 2;
	const int c2 = cols - c1;

	// C11 += A11 * B11 + A12 * B21
	matrix_multi_block(a, ar, ac, b, br, bc, r1, k1, c1, c);
	matrix_multi_block(a, ar, ac + k1, b, br + k1, bc, r1, k2, c1, c);
	// C12 += A11 * B12 + A12 * B22
	matrix_multi_block(a, ar, ac, b, br, bc + c1, r1, k1, c2, c);
	matrix_multi_block(a, ar, ac + k1, b, br + k1, bc + c1, r1, k2, c2, c);
	// C21 += A21 * B11 + A22 * B21
	matrix_multi_block(a, ar + r1, ac, b, br, bc, r2, k1, c1, c);
	matrix_multi_block(a, ar + r1, ac + k1, b, br + k1, bc, r2, k2, c1, c);
	// C22 += A21 * B12 + A22 * B22
	matrix_multi_block(a, ar + r1, ac, b, br, bc + c1, r2, k1, c2, c);
	matrix_multi_block(a, ar + r1, ac + k1, b, br + k1, bc + c1, r2, k2, c2, c);
}

// Recursive (block / divide-and-conquer) matrix multiplication.
// (m, n) is the top-left corner of the size x size block of a; (p, q) is the
// top-left corner of the size x size block of b. The product is ADDED into
// the block of c whose top-left corner is (m, q):
//   c[m + i][q + j] += sum_k a[m + i][n + k] * b[p + k][q + j]
// The caller must initialise c (normally to zero).
// Works for any size, not only powers of two; size <= 0 does nothing.
// All blocks must fit inside the 8-column arrays.
void matrix_multi_recursive(int a[][8], int m, int n, int b[][8], int p, int q, int size, int c[][8])
{
	matrix_multi_block(a, m, n, b, p, q, size, size, size, c);
}

// Writes the top-left size x size block of c to standard output: one matrix
// row per line, every element followed by a single space.
void print(int c[][8], int size)
{
	int row = 0;
	while (row < size)
	{
		const int *cell = c[row];
		const int *const row_end = cell + size;
		while (cell != row_end)
		{
			std::cout << *cell++ << " ";
		}
		std::cout << std::endl;
		++row;
	}
}



// Demo driver for the recursive algorithm: multiplies two fixed 8x8 integer
// matrices and prints the product.
int main()
{
	int a[8][8] =
	{
	{1,2,3,4,5,6,7,8},
	{2,3,4,5,6,7,8,9},
	{3,4,5,6,7,8,9,10},
	{4,5,6,7,8,9,10,11},
	{5,6,7,8,9,10,11,12},
	{6,7,8,9,10,11,12,13},
	{7,8,9,10,11,12,13,14},
	{8,9,10,11,12,13,14,15}
	};

	int b[8][8] =
	{
	{10,11,12,13,14,15,16,17},
	{11,12,13,14,15,16,17,18},
	{12,13,14,15,16,17,18,19},
	{13,14,15,16,17,18,19,20},
	{14,15,16,17,18,19,20,21},
	{15,16,17,18,19,20,21,22},
	{16,17,18,19,20,21,22,23},
	{17,18,19,20,21,22,23,24}
	};

	// matrix_multi_recursive accumulates into c with +=, so start from zero.
	int c[8][8] = {};

	matrix_multi_recursive(a, 0, 0, b, 0, 0, 8, c);
	print(c, 8);
	return 0;
}

原理:矩阵乘法的分块运算:

【2.5】矩阵分块相乘 - 知乎 (zhihu.com)

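复杂度:设 $T(n)$ 为计算两个 $n \times n$ 矩阵乘积所需的时间。每次递归把问题划分为 8 个规模为 $n/2$ 的子问题;本实现直接把结果累加到 `c` 中,划分与合并只需 $\Theta(1)$ 时间。因此 $T(n) = 8T(n/2) + \Theta(1)$,由主定理得 $T(n) = \Theta(n^{\log_2 8}) = \Theta(n^3)$,与嵌套循环法相同。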

三.矩阵乘法的Strassen算法

伪代码:

[图 3:矩阵乘法 Strassen 算法的伪代码]

C++代码:

#include "stdafx.h"
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
using namespace std;
 
 
// Square-matrix multiplication for element type T. It provides the classical
// triple-loop product (NormalMul) and Strassen's divide-and-conquer product
// (StrassenMul), plus small helpers used to build and check test matrices.
//
// Every matrix is passed as a T** table of row pointers covering at least
// size x size elements; the caller owns that storage. Scratch matrices used
// inside StrassenMul hold T (not int) and are released automatically.
template <typename T>
class Strassen
{
public:
	// MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j] over size x size.
	// Element-wise, so MatrixResult may be the same matrix as an operand.
	void ADD(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	// MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j]; may run in place.
	void SUB(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	// Classical product MatrixResult = MatrixA * MatrixB. MatrixResult must
	// not alias an input: each element is cleared before the inputs are read.
	void NormalMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	// Strassen product MatrixResult = MatrixA * MatrixB for any size >= 0.
	void StrassenMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size);
	// Gives A and B initial values: both receive the SAME rand() % 5 values.
	void FillMatrix(T **  MatrixA, T ** MatrixB, int size);
	// Sum of all elements, accumulated in an int. Used as a cheap checksum:
	// if both algorithms give the same sum the result is taken as correct
	// (equal sums do not strictly prove equal matrices).
	int   GetMatrixSum(T ** Matrix, int size);

private:
	// Owns one n x n matrix in a single contiguous buffer and exposes it as
	// a T** row table, so it can be passed to the member functions above.
	class Buffer
	{
	public:
		explicit Buffer(int n)
			: data_(static_cast<std::size_t>(n) * static_cast<std::size_t>(n)),
			  rows_(static_cast<std::size_t>(n))
		{
			for (int i = 0; i < n; i++)
			{
				rows_[i] = data_.data() + static_cast<std::size_t>(i) * static_cast<std::size_t>(n);
			}
		}
		// The row table points into data_, so a copy would dangle.
		Buffer(const Buffer &) = delete;
		Buffer & operator=(const Buffer &) = delete;

		T ** get() { return rows_.data(); }
		T * operator[](int i) { return rows_[i]; }

	private:
		std::vector<T> data_;
		std::vector<T *> rows_;
	};
};

template <typename T>
void Strassen<T>::ADD(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	for (int i = 0; i < size; i++)
	{
		for (int j = 0; j < size; j++)
		{
			MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j];
		}
	}
}

template <typename T>
void Strassen<T>::SUB(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	for (int i = 0; i < size; i++)
	{
		for (int j = 0; j < size; j++)
		{
			MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j];
		}
	}
}

template <typename T>
void Strassen<T>::NormalMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	for (int i = 0; i < size; i++)
	{
		for (int j = 0; j < size; j++)
		{
			MatrixResult[i][j] = 0;
			for (int k = 0; k < size; k++)
			{
				MatrixResult[i][j] += MatrixA[i][k] * MatrixB[k][j];
			}
		}
	}
}

template <typename T>
void Strassen<T>::FillMatrix(T **  MatrixA, T ** MatrixB, int size)
{
	for (int i = 0; i < size; i++)
	{
		for (int j = 0; j < size; j++)
		{
			MatrixA[i][j] = MatrixB[i][j] = std::rand() % 5;
		}
	}
}

template <typename T>
void Strassen<T>::StrassenMul(T **  MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
	if (size <= 0)
	{
		return; // empty product (previously this recursed without end)
	}
	if (size == 1)
	{
		MatrixResult[0][0] = MatrixA[0][0] * MatrixB[0][0];
		return;
	}
	if (size % 2 != 0)
	{
		// The 2x2 block split needs an even size. Falling back to the
		// classical product keeps sizes that are not powers of two correct
		// (e.g. 6 -> 3 -> NormalMul) instead of leaving the last row and
		// column of the result unwritten.
		NormalMul(MatrixA, MatrixB, MatrixResult, size);
		return;
	}
	// Threshold for stopping the divide-and-conquer: a check such as
	// `if (size <= 64) { NormalMul(...); return; }` here is a common
	// optimisation, because for small sizes the classical product avoids the
	// extra additions and allocations.

	const int half = size / 2;

	// Split the inputs into quadrants.
	Buffer A11(half), A12(half), A21(half), A22(half);
	Buffer B11(half), B12(half), B21(half), B22(half);
	for (int i = 0; i < half; i++)
	{
		for (int j = 0; j < half; j++)
		{
			A11[i][j] = MatrixA[i][j];
			A12[i][j] = MatrixA[i][j + half];
			A21[i][j] = MatrixA[i + half][j];
			A22[i][j] = MatrixA[i + half][j + half];

			B11[i][j] = MatrixB[i][j];
			B12[i][j] = MatrixB[i][j + half];
			B21[i][j] = MatrixB[i + half][j];
			B22[i][j] = MatrixB[i + half][j + half];
		}
	}

	// Scratch operands and the seven Strassen products.
	Buffer Temp1(half), Temp2(half);
	Buffer M1(half), M2(half), M3(half), M4(half), M5(half), M6(half), M7(half);

	// M1 = (A11 + A22)(B11 + B22)
	ADD(A11.get(), A22.get(), Temp1.get(), half);
	ADD(B11.get(), B22.get(), Temp2.get(), half);
	StrassenMul(Temp1.get(), Temp2.get(), M1.get(), half);

	// M2 = (A21 + A22) B11
	ADD(A21.get(), A22.get(), Temp1.get(), half);
	StrassenMul(Temp1.get(), B11.get(), M2.get(), half);

	// M3 = A11 (B12 - B22)
	SUB(B12.get(), B22.get(), Temp1.get(), half);
	StrassenMul(A11.get(), Temp1.get(), M3.get(), half);

	// M4 = A22 (B21 - B11)
	SUB(B21.get(), B11.get(), Temp1.get(), half);
	StrassenMul(A22.get(), Temp1.get(), M4.get(), half);

	// M5 = (A11 + A12) B22
	ADD(A11.get(), A12.get(), Temp1.get(), half);
	StrassenMul(Temp1.get(), B22.get(), M5.get(), half);

	// M6 = (A21 - A11)(B11 + B12)
	SUB(A21.get(), A11.get(), Temp1.get(), half);
	ADD(B11.get(), B12.get(), Temp2.get(), half);
	StrassenMul(Temp1.get(), Temp2.get(), M6.get(), half);

	// M7 = (A12 - A22)(B21 + B22)
	SUB(A12.get(), A22.get(), Temp1.get(), half);
	ADD(B21.get(), B22.get(), Temp2.get(), half);
	StrassenMul(Temp1.get(), Temp2.get(), M7.get(), half);

	// Combine the products into the result quadrants.
	Buffer C11(half), C12(half), C21(half), C22(half);

	// C11 = M1 + M4 - M5 + M7
	ADD(M1.get(), M4.get(), C11.get(), half);
	SUB(C11.get(), M5.get(), C11.get(), half);
	ADD(C11.get(), M7.get(), C11.get(), half);

	// C12 = M3 + M5
	ADD(M3.get(), M5.get(), C12.get(), half);

	// C21 = M2 + M4
	ADD(M2.get(), M4.get(), C21.get(), half);

	// C22 = M1 - M2 + M3 + M6
	SUB(M1.get(), M2.get(), C22.get(), half);
	ADD(C22.get(), M3.get(), C22.get(), half);
	ADD(C22.get(), M6.get(), C22.get(), half);

	// Write the quadrants back into the result. The scratch buffers free
	// themselves when this function returns.
	for (int i = 0; i < half; i++)
	{
		for (int j = 0; j < half; j++)
		{
			MatrixResult[i][j] = C11[i][j];
			MatrixResult[i][j + half] = C12[i][j];
			MatrixResult[i + half][j] = C21[i][j];
			MatrixResult[i + half][j + half] = C22[i][j];
		}
	}
}

template <typename T>
int   Strassen<T>::GetMatrixSum(T ** Matrix, int size)
{
	int sum = 0;
	for (int i = 0; i < size; i++)
	{
		for (int j = 0; j < size; j++)
		{
			sum += Matrix[i][j];
		}
	}
	return sum;
}
 
int main()
{
	long startTime_normal, endTime_normal;
	long startTime_strasse, endTime_strassen;
 
	//srand(time(0));
 
	Strassen stra;
	int N;
	cout<<"please input the size of Matrix,and the size must be the power of 2:"<>N;
 
	int ** Matrix1 = new int * [N];
	int ** Matrix2 = new int * [N];
	int ** Matrix3 = new int * [N];
	for(int i=0;i

你可能感兴趣的:(算法导论,算法,算法)