HUST算法实践_POJ3233

题目传送门

问题描述

Given a n × n matrix A and a positive integer k, find the sum S = A + A2 + A3 + … + Ak.

输入

The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ 109) and m (m < 104). Then follow n lines each containing n nonnegative integers below 32,768, giving A’s elements in row-major order.

输出

Output the elements of S modulo m in the same way as A is given.

数据结构

采用一维数组存储矩阵, 某元素的行数用标号 / n得到, 列数用标号 % n得到.

算法思想

采用倍增算法(对倍增算法不熟悉的同学可以移步https://blog.csdn.net/MaTF_/article/details/122976711?spm=1001.2014.3001.5502), 设置两个数组:

  1. a[31][maxn]: a[i]表示 A 2 i − 1 A^{2^{i-1}} A2i1对应的矩阵, 特别地, a[0]为单位矩阵
  2. b[31][maxn]: b[i]表示 A 1 + A 2 + . . . + A 2 i − 1 A^1+ A^2+...+A^{2^{i-1}} A1+A2+...+A2i1, 特别地, b[0]为空矩阵

我们可以通过以下方法得到上述的两个数组(涉及到的运算均为矩阵运算, 需通过相应的函数实现):

  1. a[i] = a[i] * a[i]
  2. b[i] = b[i-1] + b[i-1]*a[i-1]

得到数组ab后, 我们对问题进行求解, 核心代码如下:

void solve(){
	int cnt=1;
	int temp_cnt[maxn];
	cpy(temp_cnt,a[0]);		//temp为当前处理过的次数最高的A的幂
	for(int i=29;i>=0;i--){
		if(k>=1<<i){
			mul(b[i+1],temp_cnt);
			pls(ans,b[i+1]);
			mul(temp_cnt,a[i+1]);
			k-=1<<i;
		}
	}
}

代码实现

#include
#include
#define maxn 901
using namespace std;

/*倍增算法*/ 

int ans[maxn];
int c[maxn];
int a[31][maxn],b[31][maxn]; 
int n,k,m;

void pls(int x[maxn],int y[maxn]){
	for(int i=0;i<n*n;i++){
		x[i]+=y[i];
		x[i]%=m;
	}
} 
void prt(){
	for(int i=0;i<n*n;i++){
		cout<<ans[i]<<' ';
		if(i%n==n-1) cout<<'\n';
	}
}
void mul(int x[maxn],int y[maxn]){
	int temp[maxn];
	memset(temp,0,sizeof(temp));
	for(int i=0;i<n;i++){
		for(int j=0;j<n;j++){
			for(int k=0;k<n;k++){
				temp[i*n+j]+=(x[i*n+k]*y[k*n+j])%m;
				temp[i*n+j]%=m;
			} 
			
		}
	}
	for(int i=0;i<n*n;i++) x[i]=temp[i];
}
void cpy(int x[maxn],int y[maxn]){
	for(int i=0;i<n*n;i++) x[i]=y[i];
}
void pre(){		//注意:a[1]对应2^0! 
	cpy(a[1],c);
	cpy(b[1],c);
	for(int i=2;i<=30;i++){
		cpy(a[i],a[i-1]);
		mul(a[i],a[i-1]);
		cpy(b[i],b[i-1]);
		mul(b[i],a[i-1]);
		pls(b[i],b[i-1]);
	}
}
void solve(){
	int cnt=1;
	int temp_cnt[maxn];
	cpy(temp_cnt,a[0]);
	for(int i=29;i>=0;i--){
		if(k>=1<<i){
			mul(b[i+1],temp_cnt);
			pls(ans,b[i+1]);
			mul(temp_cnt,a[i+1]);
			k-=1<<i;
		}
	}
}
signed main(){
	cin>>n>>k>>m;
	for(int i=0;i<n*n;i++){
		cin>>c[i]; 
		if(i%n==i/n) a[0][i]=1;
		else a[0][i]=0;
	}
	pre();
	solve();
	prt();
	return 0;
}

你可能感兴趣的:(算法,数据结构)