Given a n × n matrix A and a positive integer k, find the sum S = A + A2 + A3 + … + Ak.
The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ 109) and m (m < 104). Then follow n lines each containing n nonnegative integers below 32,768, giving A’s elements in row-major order.
Output the elements of S modulo m in the same way as A is given.
采用一维数组存储矩阵, 某元素的行数用标号 / n
得到, 列数用标号 % n
得到.
采用倍增算法(对倍增算法不熟悉的同学可以移步https://blog.csdn.net/MaTF_/article/details/122976711?spm=1001.2014.3001.5502), 设置两个数组:
a[31][maxn]
: a[i]
表示 A 2 i − 1 A^{2^{i-1}} A2i−1对应的矩阵, 特别地, a[0]为单位矩阵b[31][maxn]
: b[i]
表示 A 1 + A 2 + . . . + A 2 i − 1 A^1+ A^2+...+A^{2^{i-1}} A1+A2+...+A2i−1, 特别地, b[0]为空矩阵我们可以通过以下方法得到上述的两个数组(涉及到的运算均为矩阵运算, 需通过相应的函数实现):
a[i] = a[i] * a[i]
b[i] = b[i-1] + b[i-1]*a[i-1]
得到数组a
和b
后, 我们对问题进行求解, 核心代码如下:
void solve(){
int cnt=1;
int temp_cnt[maxn];
cpy(temp_cnt,a[0]); //temp为当前处理过的次数最高的A的幂
for(int i=29;i>=0;i--){
if(k>=1<<i){
mul(b[i+1],temp_cnt);
pls(ans,b[i+1]);
mul(temp_cnt,a[i+1]);
k-=1<<i;
}
}
}
#include
#include
#define maxn 901
using namespace std;
/*倍增算法*/
int ans[maxn];
int c[maxn];
int a[31][maxn],b[31][maxn];
int n,k,m;
void pls(int x[maxn],int y[maxn]){
for(int i=0;i<n*n;i++){
x[i]+=y[i];
x[i]%=m;
}
}
void prt(){
for(int i=0;i<n*n;i++){
cout<<ans[i]<<' ';
if(i%n==n-1) cout<<'\n';
}
}
void mul(int x[maxn],int y[maxn]){
int temp[maxn];
memset(temp,0,sizeof(temp));
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
for(int k=0;k<n;k++){
temp[i*n+j]+=(x[i*n+k]*y[k*n+j])%m;
temp[i*n+j]%=m;
}
}
}
for(int i=0;i<n*n;i++) x[i]=temp[i];
}
void cpy(int x[maxn],int y[maxn]){
for(int i=0;i<n*n;i++) x[i]=y[i];
}
void pre(){ //注意:a[1]对应2^0!
cpy(a[1],c);
cpy(b[1],c);
for(int i=2;i<=30;i++){
cpy(a[i],a[i-1]);
mul(a[i],a[i-1]);
cpy(b[i],b[i-1]);
mul(b[i],a[i-1]);
pls(b[i],b[i-1]);
}
}
void solve(){
int cnt=1;
int temp_cnt[maxn];
cpy(temp_cnt,a[0]);
for(int i=29;i>=0;i--){
if(k>=1<<i){
mul(b[i+1],temp_cnt);
pls(ans,b[i+1]);
mul(temp_cnt,a[i+1]);
k-=1<<i;
}
}
}
signed main(){
cin>>n>>k>>m;
for(int i=0;i<n*n;i++){
cin>>c[i];
if(i%n==i/n) a[0][i]=1;
else a[0][i]=0;
}
pre();
solve();
prt();
return 0;
}