算法导论学习笔记——4.2矩阵乘法的Strassen算法

4.2矩阵乘法的Strassen算法

Strassen算法的时间复杂度为o(n^lg7),
是一种简化矩阵乘法的算法
下面是矩阵乘法的伪代码算法导论学习笔记——4.2矩阵乘法的Strassen算法_第1张图片
直接给出我以前整理过的模板
这边其实可以特判0来进行一个小优化

void cheng(ll a[][N],ll b[][N],ll n)//a=a*b
{
     
	memset(tmp,0,sizeof(tmp));
	for(ll i=0;i<n;i++)
		for(ll j=0;j<n;j++)
			for(ll k=0;k<n;k++)
				tmp[i][j]=(tmp[i][j]+a[i][k]*b[k][j])%mod;
	for(ll i=0;i<n;i++)
		for(ll j=0;j<n;j++)
			a[i][j]=tmp[i][j];
}

时间复杂度:o(n^3)

下面是原书中的伪代码算法导论学习笔记——4.2矩阵乘法的Strassen算法_第2张图片
下面是原书中的Strassen算法(已翻译)
算法导论学习笔记——4.2矩阵乘法的Strassen算法_第3张图片
先把每个矩阵分割为4份,然后创建如下10个中间矩阵:

S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12

接着,计算7次矩阵乘法:

P1 = A11 • S1
P2 = S2 • B22
P3 = S3 • B11
P4 = A22 • S4
P5 = S5 • S6
P6 = S7 • S8
P7 = S9 • S10

最后,根据这7个结果就可以计算出C矩阵:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7

最后合并即可得到结果

下面是我整理的快速幂解法


void mi(ll a[][N],ll n)
{
     
	memset(res,0,sizeof(res));
	for(ll i=0;i<N;i++)
		res[i][i]=1;
	while(n!=0)
	{
     
		if(n&1!=0)
			cheng(res,a,N);
		cheng(a,a,N);
		n>>=1;
	}
}
	

全代码
1.数组写法

#include 
using namespace std;
typedef long long ll;
const int mod=9973;
const int N=10;
ll cs[N][N],tmp[N][N],res[N][N];
void cheng(ll a[][N],ll b[][N],ll n)//a=a*b
{
     
	memset(tmp,0,sizeof(tmp));
	for(ll i=0;i<n;i++)
		for(ll j=0;j<n;j++)
			for(ll k=0;k<n;k++)
				tmp[i][j]=(tmp[i][j]+a[i][k]*b[k][j])%mod;
	for(ll i=0;i<n;i++)
		for(ll j=0;j<n;j++)
			a[i][j]=tmp[i][j];
}
void mi(ll a[][N],ll n)
{
     
	memset(res,0,sizeof(res));
	for(ll i=0;i<N;i++)
		res[i][i]=1;
	while(n!=0)
	{
     
		if(n&1!=0)
			cheng(res,a,N);
		cheng(a,a,N);
		n>>=1;
	}
}
void solve()
{
     
	ll n1,k1;
	scanf("%lld%lld",&n1,&k1);
	for(ll i=0;i<n1;i++)
		for(ll j=0;j<n1;j++)
			scanf("%lld",&cs[i][j]);
	mi(cs,k1);
	ll ans=0;
	for(ll i=0;i<n1;i++)
		ans+=res[i][i],ans%=mod;
	printf("%lld\n",ans);
}
int main()
{
     

    ll o; 
    scanf("%lld",&o);
    while(o--)
    	solve();
    return 0;
}


2.结构体写法

#define MAXN 2
#define mod int(1e4)  
struct Matrix
{
     
	int mat[MAXN][MAXN];
	Matrix() {
     }
	Matrix operator*(Matrix const &b)const
	{
     
		Matrix res;
		memset(res.mat, 0, sizeof(res.mat));
		for (int i = 0 ;i < MAXN; i++)
			for (int j = 0; j < MAXN; j++)
				for (int k = 0; k < MAXN; k++)
					res.mat[i][j] = (res.mat[i][j]+this->mat[i][k] * b.mat[k][j])%mod;
		return res;
	}
};
Matrix pow_mod(Matrix base, int n)
{
     
	Matrix res;
	memset(res.mat, 0, sizeof(res.mat));
	for (int i = 0; i < MAXN; i++)
		res.mat[i][i] = 1;
	while (n > 0)
	{
     
		if (n & 1) res = res*base;
		base = base*base;
		n >>= 1;
	}
	return res;
}

矩阵快速幂算法解析:
要计算一个矩阵A的n次幂,先判断n的奇偶,
如果m是奇数,就让ans*=A,
然后A=A^2,n=n/2。
例如 n=5,
1.ans*=A,A=A^2,n=n/2
(ans=A,A=A^2,n=2)
2.A=A^2,n=n/2
(ans=A,A=A^4,n=1)
3.ans*=A,A=A^2,n=n/2
(ans=A^5,A=A ^8,n=0)
退出循环
时间复杂度o(lgn)

你可能感兴趣的:(算法导论)