poj 3233 Matrix Power Series (矩阵快速幂 + 二分)

http://poj.org/problem?id=3233

题意:  

题意:已知一个n*n的矩阵A,和一个正整数k,求 S A A2 A3 + … + Ak

 

/*第一次写时,写挫啦,tle 一次,后来,稍微改动了一下,ac

矩阵快速幂。首先我们知道 A^x 可以用矩阵快速幂求出来。
其次可以对k进行二分,
每次将规模减半,分k为奇偶两种情况,如当k = 6和k = 7时有:
 S(6) = (1 + A^3) * (A + A^2 + A^3) = (1 + A^3) * S(3)。
 s(7)  = (1 + A^3) * (A + A^2 + A^3) + A^7 = (1 + A^3)*(s(3)) + A^7;
*/

#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include< set>
#include<map>
#include<queue>
#include<vector>
#include< string>
#define Min(a,b) a<b?a:b
#define Max(a,b) a>b?a:b
#define CL(a,num) memset(a,num,sizeof(a));
#define maxn  40
#define eps  1e-6
#define inf 9999999
#define mx 1<<60

using  namespace std;
struct   martrix
{
     int m[ 31][ 31];

};
int n,mod,k;
 martrix  mtadd (martrix a,martrix b)
   {
        int i ,j;
       martrix c;
        for(i =  0;i < n;i++)
       {

            for(j =  0 ; j< n;j++)
           {
               c.m[i][j] =  0;
               c.m[i][j] = (a.m[i][j] + b.m[i][j])%mod;

           }
       }
        return c;
   }

martrix mtmul(martrix a,martrix b)
{
    martrix c;
     int i,j,k;
     for(i =  0; i < n; i++)
    {
         for(j =  0; j < n;j++)
        {
            c.m[i][j] =  0;
             for(k =  0 ; k < n;k++)
            {
                c.m[i][j] += a.m[i][k] * b.m[k][j];
                c.m[i][j] %=mod;
            }
        }
    }


     return c;
}
martrix mtpow(martrix d, int k)
{   martrix a;
     if(k ==  1return d ;
     int mid = k /  2;
    a = mtpow(d,k/ 2);
    a = mtmul(a,a);
     if(k &  1)
    {
        a = mtmul(a,d);
    }
     return a;


}
martrix solve(martrix a, int k)
{
    martrix b,c,d;
     if(k ==  1return a ;
     int mid = k /  2;
    b = mtpow(a,mid);
    d = solve(a,mid);

    c = mtmul(b,d) ;
    c = mtadd(c,d);

     if(k& 1)
    {
        c = mtadd(mtpow(a,k),c);
    }
     return  c;
}
int main()
{

    martrix a,b;

      int i,j;
     while(scanf( " %d%d%d ",&n,&k,&mod)!=EOF)
    {
         for(i =  0; i < n;i++)
        {
             for(j =  0; j < n;j++)
             scanf( " %d ",&a.m[i][j]);
        }
        b = solve(a,k);
         for(i = 0 ; i < n;i++)
        {
             for(j =  0; j <n ;j++)
            {
                 if(j ==  0)printf( " %d ",b.m[i][j] % mod);
                 else printf( "  %d ",b.m[i][j] % mod);
            }
            printf( " \n ");
        }
    }
}

你可能感兴趣的:(Matrix)