HDU 6470 Count(矩阵快速幂)

题意:

第一头牛和第二头牛的号码为 1 和 2,其余的牛的号码需要通过公式求得:

f(n) = f(n - 1) + 2 * f(n - 2) + n * n * n , (f(1) = 1,f(2) = 2,分别是第一头和第二头)。

题解:

与前几项有关的加法公式一般都是用矩阵快速幂来解的。。。所以解法就是矩阵快速幂。

很明显,右矩阵肯定与 f(n - 1)、f(n - 2) 和 n * n * n。

问题来了,怎么把 n * n * n 转移为 (n + 1) * (n + 1) * (n + 1)?直接算。。。可以得到 :

n * n * n + 3 * n * n - 3 * n + 1 = (n + 1) * (n + 1) * (n + 1),所以我们还要有 n * n 、 n 和 1,同理可以得到:

n * n  + 2 * n + 1 = (n + 1) * (n + 1)。

所以右矩阵就这么可以得出来了:

\begin{bmatrix}f(n-1) \\ f(n-2) \\ n^{3} \\ n^{2} \\ n \\1 \end{bmatrix}

左矩阵根据关系可以得:

\begin{bmatrix} 1 & 2 & 1 & 0 & 0 & 0\\ 1 & 0 & 0 & 0 & 0 & 0\\ 0 & 0 & 1 & 3 & -3 & 1\\ 0 & 0 & 0 & 1 & 2 & 1\\ 0 & 0 & 0 & 0 & 1 & 1\\ 0 & 0 & 0 & 0 & 0 & 1 \end{bmatrix}

推出来就可以直接对左矩阵快速幂,直接输出。

AC代码:

#include 
#include  
#include   
#include   
#include    
#include    
#include    
#include    
#include     
#include     
#include     
#include     
#include      
#include       
#include       
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define line printf("---------------------------\n")
#define mem(a, b) memset(a, b, sizeof(a))
#define pi acos(-1)
using namespace std;
typedef long long ll;
const double eps = 1e-9;
const int inf = 0x3f3f3f3f;
const int mod = 123456789;
const int maxn = 2000+10;

struct node {
	ll mrx[6][6];
	node(){}
	node(ll mapp[6][6]) {
		for(int i = 0; i < 6; i++) {
			for(int j = 0; j < 6; j++) {
				mrx[i][j] = mapp[i][j];
			}
		}
	}
	node(ll x) {
		for(int i = 0; i < 6; i++) {
			for(int j = 0; j < 6; j++) {
				if(i == j) {
					mrx[i][j] = x;
				} else {
					mrx[i][j] = 0;
				} 
			}
		}
	}
	friend node operator * (node a, node b) {
		node ans;
		for(int i = 0; i < 6; i++) {
			for(int j = 0; j < 6; j++) {
				ll p = 0;
				for(int k = 0; k < 6; k++) {
					p = (p + (a.mrx[i][k] * b.mrx[k][j]) % mod) % mod;
				}
				ans.mrx[i][j] = p;
			}
		}
		return ans;
	}
	void print() {
		for(int i = 0; i < 6; i++) {
			for(int j = 0; j < 6; j++) {
				cout << mrx[i][j] << " ";
			}
			cout << endl;
		}
	}
};

node qpow(node a, ll b) {
	node sum(1);
	while(b) {
		if(b & 1) {
			sum = sum * a;
		}
		a = a * a;
		b >>= 1;
	}
	return sum;
}

int main() {
	ll mapp[6][6] = {{1,2,1,0,0,0},
					  {1,0,0,0,0,0},
				   	  {0,0,1,3,-3,1},
					  {0,0,0,1,2,1},
					  {0,0,0,0,1,1},
					  {0,0,0,0,0,1}};
	node temp(mapp);
	int t;
	cin >> t;
	while(t--) {
		ll n;
		cin >> n;
		node m = qpow(temp, n - 2);
		ll ans[6] = {2,1,27,16,4,1};
		for(int i = 0; i < 6; i++) {
			ll sum = 0;
			for(int j = 0; j < 6; j++) {
				sum = (sum + (ans[j] * m.mrx[i][j]) % mod) % mod;
			}
			ans[i] = sum;
		}
		cout << ans[0] << endl;
	}
}

 

你可能感兴趣的:(HDU题解)