目录
一.矩阵乘法的嵌套循环算法
二.矩阵乘法的递归算法
三.矩阵乘法的Strassen算法
伪代码:
C++代码:
//1.矩阵乘法的嵌套循环算法
#include
using namespace std;
void Square_MA_MU(int a[][3],int b[][3],int c[][3],int n) //传递二维数组参数时必须要确定列数
{
for (int i = 0; i < n; i++) //行遍历
{
for(int j=0;j
伪代码:
C++代码:
#include
using namespace std;
// Divide-and-conquer matrix multiplication on square blocks of 8-column arrays.
//
// Adds the product of two size x size blocks into c:
//     c[m+i][q+j] += sum over k of a[m+i][n+k] * b[p+k][q+j],   0 <= i, j < size
// (m, n) is the top-left corner of the block in a and (p, q) the top-left
// corner of the block in b. The result block starts at row m (the row of a)
// and column q (the column of b).
// c is accumulated into, not overwritten, so the caller must zero it first.
// The caller must also keep every touched index inside the 8-column arrays.
//
// Even sizes are split into four half-size quadrants, and the eight quadrant
// products are accumulated recursively. Odd sizes, including the 1x1 base
// case, cannot be halved into equal quadrants, so they are multiplied
// directly; halving them with size / 2 would silently drop the last row and
// column. Non-positive sizes add nothing; size 0 would otherwise recurse
// forever.
void matrix_multi_recursive(int a[][8], int m, int n, int b[][8], int p, int q, int size, int c[][8])
{
    if (size <= 0)
    {
        return;
    }
    if (size % 2 != 0)
    {
        // Direct product for blocks that cannot be split evenly (size 1 reduces
        // to c[m][q] += a[m][n] * b[p][q]).
        for (int i = 0; i < size; i++)
        {
            for (int j = 0; j < size; j++)
            {
                int sum = 0;
                for (int k = 0; k < size; k++)
                {
                    sum += a[m + i][n + k] * b[p + k][q + j];
                }
                c[m + i][q + j] += sum;
            }
        }
        return;
    }

    int half_size = size / 2;
    // C11 += A11*B11 + A12*B21
    matrix_multi_recursive(a, m, n, b, p, q, half_size, c);
    matrix_multi_recursive(a, m, n + half_size, b, p + half_size, q, half_size, c);
    // C12 += A11*B12 + A12*B22
    matrix_multi_recursive(a, m, n, b, p, q + half_size, half_size, c);
    matrix_multi_recursive(a, m, n + half_size, b, p + half_size, q + half_size, half_size, c);
    // C21 += A21*B11 + A22*B21
    matrix_multi_recursive(a, m + half_size, n, b, p, q, half_size, c);
    matrix_multi_recursive(a, m + half_size, n + half_size, b, p + half_size, q, half_size, c);
    // C22 += A21*B12 + A22*B22
    matrix_multi_recursive(a, m + half_size, n, b, p, q + half_size, half_size, c);
    matrix_multi_recursive(a, m + half_size, n + half_size, b, p + half_size, q + half_size, half_size, c);
}
// Writes the top-left size x size corner of c to standard output: one matrix
// row per line, every element followed by a single space, and each line ended
// with std::endl (which also flushes).
void print(int c[][8], int size)
{
    for (int row = 0; row < size; ++row)
    {
        int *const rowValues = c[row];
        for (int col = 0; col < size; ++col)
        {
            std::cout << rowValues[col] << " ";
        }
        std::cout << std::endl;
    }
}
// Demo driver: multiplies two fixed 8x8 integer matrices with
// matrix_multi_recursive and prints the 8x8 product.
int main()
{
    int a[8][8] =
    {
        {1,2,3,4,5,6,7,8},
        {2,3,4,5,6,7,8,9},
        {3,4,5,6,7,8,9,10},
        {4,5,6,7,8,9,10,11},
        {5,6,7,8,9,10,11,12},
        {6,7,8,9,10,11,12,13},
        {7,8,9,10,11,12,13,14},
        {8,9,10,11,12,13,14,15}
    };
    int b[8][8] =
    {
        {10,11,12,13,14,15,16,17},
        {11,12,13,14,15,16,17,18},
        {12,13,14,15,16,17,18,19},
        {13,14,15,16,17,18,19,20},
        {14,15,16,17,18,19,20,21},
        {15,16,17,18,19,20,21,22},
        {16,17,18,19,20,21,22,23},
        {17,18,19,20,21,22,23,24}
    };
    // matrix_multi_recursive accumulates into c, so it must start all-zero.
    int c[8][8] = {};
    matrix_multi_recursive(a, 0, 0, b, 0, 0, 8, c);
    print(c, 8);
    // No console pause here: a system("pause") placed after `return 0;`
    // can never execute, so that dead call and the duplicate return were dropped.
    return 0;
}
原理:矩阵乘法的分块运算:
【2.5】矩阵分块相乘 - 知乎 (zhihu.com)
复杂度:子矩阵只通过下标划分、不复制数据,每次调用做 8 次规模减半的递归,故 $T(n)=8T(n/2)+\Theta(1)$,解得 $T(n)=\Theta(n^{3})$,与嵌套循环算法相同。
伪代码:
C++代码:
#include "stdafx.h"
#include
#include
#include
#include
using namespace std;
// Square-matrix helpers used to compare classic and Strassen multiplication.
// Matrices are passed as T** row tables; every operation works on the
// top-left size x size elements.
template <typename T>
class Strassen
{
public:
    // MatrixResult = MatrixA + MatrixB, element by element.
    void ADD(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size);
    // MatrixResult = MatrixA - MatrixB, element by element.
    void SUB(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size);
    // MatrixResult = MatrixA * MatrixB with the classic triple loop.
    void NormalMul(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size);
    // MatrixResult = MatrixA * MatrixB with Strassen's seven-product recursion.
    void StrassenMul(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size);
    // Gives matrices A and B their initial values.
    void FillMatrix(T ** MatrixA, T ** MatrixB, int size);
    // Sum of all matrix elements. If both algorithms produce results with
    // equal sums, the algorithm is considered correct.
    int GetMatrixSum(T ** Matrix, int size);
};
template
void Strassen::ADD(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
MatrixResult[i][j] = MatrixA[i][j] + MatrixB[i][j];
}
}
}
template
void Strassen::SUB(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
MatrixResult[i][j] = MatrixA[i][j] - MatrixB[i][j];
}
}
}
template
void Strassen::NormalMul(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
MatrixResult[i][j] = 0;
for(int k = 0; k < size; k++)
MatrixResult[i][j] += MatrixA[i][k] * MatrixB[k][j];
}
}
}
template
void Strassen::FillMatrix(T ** MatrixA, T ** MatrixB, int size)//给A、B矩阵赋初值
{
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
MatrixA[i][j] = MatrixB[i][j] = rand() % 5;
}
}
}
template
void Strassen::StrassenMul(T ** MatrixA, T ** MatrixB, T ** MatrixResult, int size)
{
// if ( size <= 64 )
//分治门槛,小于这个值时不再进行递归计算,而是采用常规矩阵计算方法
// {
// NormalMul(MatrixA, MatrixB, MatrixResult, size);
// }
if(size == 1)
{
MatrixResult[0][0] = MatrixA[0][0] * MatrixB[0][0];
}
else
{
int half_size = size / 2;
T ** A11; T ** A12; T ** A21; T ** A22;
T ** B11; T ** B12; T ** B21; T ** B22;
T ** C11; T ** C12; T ** C21; T** C22;
T ** M1; T ** M2; T ** M3; T ** M4; T ** M5; T ** M6; T ** M7;
T ** MatrixTemp1; T ** MatrixTemp2;
A11 = new int * [half_size];
A12 = new int * [half_size];
A21 = new int * [half_size];
A22 = new int * [half_size];
B11 = new int * [half_size];
B12 = new int * [half_size];
B21 = new int * [half_size];
B22 = new int * [half_size];
C11 = new int * [half_size];
C12 = new int * [half_size];
C21 = new int * [half_size];
C22 = new int * [half_size];
M1 = new int * [half_size];
M2 = new int * [half_size];
M3 = new int * [half_size];
M4 = new int * [half_size];
M5 = new int * [half_size];
M6 = new int * [half_size];
M7 = new int * [half_size];
MatrixTemp1 = new int * [half_size];
MatrixTemp2 = new int * [half_size];
for(int i = 0; i < half_size; i++)
{
A11[i] = new int[half_size];
A12[i] = new int[half_size];
A21[i] = new int[half_size];
A22[i] = new int[half_size];
B11[i] = new int[half_size];
B12[i] = new int[half_size];
B21[i] = new int[half_size];
B22[i] = new int[half_size];
C11[i] = new int[half_size];
C12[i] = new int[half_size];
C21[i] = new int[half_size];
C22[i] = new int[half_size];
M1[i] = new int[half_size];
M2[i] = new int[half_size];
M3[i] = new int[half_size];
M4[i] = new int[half_size];
M5[i] = new int[half_size];
M6[i] = new int[half_size];
M7[i] = new int[half_size];
MatrixTemp1[i] = new int[half_size];
MatrixTemp2[i] = new int[half_size];
}
//赋值
for(int i = 0; i < half_size; i++)
{
for(int j = 0; j < half_size; j++)
{
A11[i][j] = MatrixA[i][j];
A12[i][j] = MatrixA[i][j+half_size];
A21[i][j] = MatrixA[i+half_size][j];
A22[i][j] = MatrixA[i+half_size][j+half_size];
B11[i][j] = MatrixB[i][j];
B12[i][j] = MatrixB[i][j+half_size];
B21[i][j] = MatrixB[i+half_size][j];
B22[i][j] = MatrixB[i+half_size][j+half_size];
}
}
//calculate M1
ADD(A11, A22, MatrixTemp1, half_size);
ADD(B11, B22, MatrixTemp2, half_size);
StrassenMul(MatrixTemp1, MatrixTemp2, M1,half_size);
//calculate M2
ADD(A21, A22, MatrixTemp1, half_size);
StrassenMul(MatrixTemp1, B11, M2, half_size);
//calculate M3
SUB(B12, B22, MatrixTemp1, half_size);
StrassenMul(A11, MatrixTemp1, M3, half_size);
//calculate M4
SUB(B21, B11, MatrixTemp1, half_size);
StrassenMul(A22, MatrixTemp1, M4, half_size);
//calculate M5
ADD(A11, A12, MatrixTemp1, half_size);
StrassenMul(MatrixTemp1, B22, M5, half_size);
//calculate M6
SUB(A21, A11, MatrixTemp1, half_size);
ADD(B11, B12, MatrixTemp2, half_size);
StrassenMul(MatrixTemp1, MatrixTemp2, M6, half_size);
//calculate M7
SUB(A12, A22, MatrixTemp1, half_size);
ADD(B21, B22, MatrixTemp2, half_size);
StrassenMul(MatrixTemp1, MatrixTemp2, M7, half_size);
//C11
ADD(M1, M4, C11, half_size);
SUB(C11, M5, C11, half_size);
ADD(C11, M7, C11, half_size);
//C12
ADD(M3, M5, C12, half_size);
//C21
ADD(M2, M4, C21, half_size);
//C22
SUB(M1, M2, C22, half_size);
ADD(C22, M3, C22, half_size);
ADD(C22, M6, C22, half_size);
//赋值
for(int i = 0; i < half_size; i++)
{
for(int j = 0; j < half_size; j++)
{
MatrixResult[i][j] = C11[i][j];
MatrixResult[i][j+half_size] = C12[i][j];
MatrixResult[i+half_size][j] = C21[i][j];
MatrixResult[i+half_size][j+half_size] = C22[i][j];
}
}
//释放申请的内存
for(int i = 0; i < half_size; i++)
{
delete[] A11[i];
delete[] A12[i];
delete[] A21[i];
delete[] A22[i];
delete[] B11[i];
delete[] B12[i];
delete[] B21[i];
delete[] B22[i];
delete[] C11[i];
delete[] C12[i];
delete[] C21[i];
delete[] C22[i];
delete[] M1[i];
delete[] M2[i];
delete[] M3[i];
delete[] M4[i];
delete[] M5[i];
delete[] M6[i];
delete[] M7[i];
delete[] MatrixTemp1[i];
delete[] MatrixTemp2[i];
}
delete[] A11;
delete[] A12;
delete[] A21;
delete[] A22;
delete[] B11;
delete[] B12;
delete[] B21;
delete[] B22;
delete[] C11;
delete[] C12;
delete[] C21;
delete[] C22;
delete[] M1;
delete[] M2;
delete[] M3;
delete[] M4;
delete[] M5;
delete[] M6;
delete[] M7;
delete[] MatrixTemp1;
delete[] MatrixTemp2;
}
}
template
int Strassen::GetMatrixSum(T ** Matrix, int size)
{
int sum = 0;
for(int i = 0; i < size; i++)
{
for(int j = 0; j < size; j++)
{
sum += Matrix[i][j];
}
}
return sum;
}
int main()
{
long startTime_normal, endTime_normal;
long startTime_strasse, endTime_strassen;
//srand(time(0));
Strassen stra;
int N;
cout<<"please input the size of Matrix,and the size must be the power of 2:"<>N;
int ** Matrix1 = new int * [N];
int ** Matrix2 = new int * [N];
int ** Matrix3 = new int * [N];
for(int i=0;i
核心思想:令递归树不那么茂盛一点,即只进行七次递归而不是八次。
复杂度: