http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3538
题意:给n个点,用4种颜色给n个点着色,其中有m个点的颜色已经确定,要求每两个相邻的点着不同的颜色。问共有多少种方法。
算法1:DP, 由于每个位置着什么颜色只与其前一个位置的颜色有关,因此可以考虑dp[i][j] , 到i个点,第i个点着j色,共有的种数。但是这里的n<=1e7,这样的算法显然会超时。
算法2 :由于n 的范围很大, 这里我们不得不去考虑用数学的方法求解,而不能去分析中间的过程(因此会超时)。因为正是每个已确定的点阻碍了我们的考虑,因此我们这里可以将整个线段的情况分开考虑。
情形1 :没有确定点,则 ans = 4 * 3^(n-1) ;
情形2 :n>=1
(1): 00..(L个没有确定的位置)..00X 这种情况的种数:3^L ;
(2):X1 00..(L个没有确定的位置)...00 X2 :这里又要分类讨论X1 ?= X2
为了讨论方便,我们记 F(n)为n个空位置,两端的元素相同时的种数,D(n)为n个空位置,两端的元素不相同时的种数。
1): X1 == X2,F(n) = 3 * F(n-2) + 6 * D(n-2) ; F(0) = 0 ; F(1) = 3 ;
2) : X1 != X2 , D(n) = 2 * F(n-2) + 7 * D(n-2) ; D(0) =1 ;D(1) = 2 ;
得出了上面的两个递推关系之后,下面我们的任务就是要分别求出F(n) 和 D(n) ,这里要母函数的方法求解。记:
A(x) = F(0) + F(1)*X + F(2)*X^2 + .......
B(x) = D(0) + D(1) *X + D(2) *X^2 + .......
我们的目标是求F(n) ,为了求F(n),我们可以先求出A(x) ,这样一种方法就是用A(x) 、B(x) 代人上面的两个递推关系式中,从而求解出A(x), B(x)。中间的过程这里就不详细展开了, 最后求得
F(n) = 3/4 * (3^L - (-1)^L) ; D(n) = (3^(L+1) + (-1)^L) / 4 ;
最后只要将每一块的种数相乘即可。
通过这题,了解了两种求解a ^ b 的log(b)的算法,
分别是:
long long cal( int n ) //时间复杂度也是log(n) { long long res=1; for (int t=0 ; n ; n>>=1 ,t++ ) if(n&1)res=res*pow3[t]%mod; return res; }这里的pow3[t] 是 3^2^t ;另外一种是分治的思想:
long long cal(int num) //分治的思想求a^b,时间复杂度为log(b) { if(num == 1) return 3 ; if(num == 0) return 1; long long res ; if(num&1) { long long res2 = cal(num/2) ; res = ((res2*res2) % mod) *3 % mod ; } else { long long res2 = cal(num/2) ; res = (res2*res2) % mod; } return res ; }本题的代码:
#include<stdio.h> #include<iostream> #include<algorithm> using namespace std; typedef unsigned long long ll; //取模运算,这里要是用long long 就不行,只能用unsigned long long,不懂。。 const ll mod_ = 1000000007ll ; const ll mod = 4*1000000007ll; /* long long pow3[100] ; long long cal( int n ) //另外的一种求3^n的方法,时间复杂度也是log(n) { //cout << n << endl; long long res=1; for (int t=0 ; n ; n>>=1 ,t++ ) if(n&1)res=res*pow3[t]%mod; //cout << res << endl; return res; } */ struct Node{ int pos ; char team ; friend bool operator < (Node a , Node b) { return a.pos < b.pos ; } }arr[15] ; long long cal(int num) //分治的思想求a^b,时间复杂度为log(b) { if(num == 1) return 3 ; if(num == 0) return 1; long long res ; if(num&1) { long long res2 = cal(num/2) ; res = ((res2*res2) % mod) *3 % mod ; } else { long long res2 = cal(num/2) ; res = (res2*res2) % mod; } return res ; } int main() { int n,m; char c; /* pow3[0]=3; for (int i=0 ; i<50 ;++i) { pow3[i+1]=pow3[i]*pow3[i]%mod; //printf("%lld %lld\n",pow3[i] * pow3[i] ,pow3[i+1]); } */ while(scanf("%d %d",&n,&m)!=EOF) { long long res = 1; if(m==0){ res = ( cal(n-1) * 4 ) % mod_ ; cout<<res<<endl ; continue ; } for(int i=0;i<m;i++) { scanf("%d%c%c" ,&arr[i].pos ,&c , &arr[i].team ); } sort(arr,arr+m); res = res * cal(arr[0].pos-1+n-arr[m-1].pos) % mod_ ; for( int i = 1 ; i < m ; i ++ ) { int L = arr[i].pos - arr[i-1].pos - 1; long long num ; if(arr[i].team == arr[i-1].team) { num = ((3*(cal(L) + (L&1)?1:(-1)))%mod )/4 ; //这里是 %mod ,一开始没有注意,不能 %mod_ ..以后要记住。。 res = res * num % mod_ ; } else if(arr[i].team != arr[i-1].team){ num = ((cal(L+1)-((L&1)?1:(-1)))%mod ) / 4 ; //同理。。 res = res * num % mod_ ; } } //printf("%lld\n",res%mod_); cout<<(res%mod_)<<endl ; } return 0; }