【NTT】 HDOJ 5322 Hope

推出递推公式 $dp_i = \sum_{j=0}^{i-1} dp_j \cdot \frac{(i-1)!}{j!} \cdot (i-j)^2$（其中 $dp_0 = 1$）以后，就可以用 CDQ 分治 + NTT 来计算了。

#include <bits/stdc++.h>
using namespace std;
using LL = long long;

// Prime modulus of the problem. 998244353 = 119 * 2^23 + 1, so power-of-two
// NTT lengths up to 2^23 are available.
constexpr int mod = 998244353;
// Capacity of every global table below.
constexpr int maxn = 300005;

LL f[maxn];    // f[i]  = i! mod `mod`            (filled by init)
LL g[maxn];    // g[i]  = (i!)^{-1} mod `mod`     (filled by init)
LL dp[maxn];   // dp[i] = answer printed for query i
LL A[maxn];    // NTT scratch: left operand, later the convolution result
LL B[maxn];    // NTT scratch: right operand
int m1[maxn];  // m1[k] = w^k for the Top-th root of unity w built in init
int m2[maxn];  // twiddle table for the transform length currently in use
int Top;       // number of valid entries in m1 (set by init)

// Modular exponentiation by square-and-multiply: returns a^b mod `mod`.
// b is expected to be non-negative (b == 0 gives 1).
// The base is first reduced into [0, mod) so that base * base stays below
// mod^2 < 2^63 and cannot overflow, even when `a` is large or negative.
LL powmod(LL a, LL b)
{
	LL base = a % mod;
	if(base < 0) base += mod;  // C++ '%' keeps the sign of the dividend
	LL res = 1;
	while(b) {
		if(b % 2) res = res * base % mod;
		base = base * base % mod;
		b /= 2;
	}
	return res;
}

// In-place iterative radix-2 number-theoretic transform of a[0..n-1] modulo mod.
//   on ==  1 : forward transform.
//   on == -1 : inverse transform; the result is additionally multiplied by n^{-1}.
// Preconditions (NOT checked here):
//   - n is a power of two with n <= Top and Top % n == 0; otherwise Top/n is
//     0 or inexact and every twiddle factor below is silently wrong;
//   - every a[i] is already reduced into [0, mod) (the butterflies correct
//     with only a single +/- mod).
// Reads the global root table m1 (built by init) and overwrites the global m2.
void NTT(LL *a, int n, int on) {
	// Bit-reversal permutation: j is advanced as a "reversed" counter (add one
	// at the top bit and propagate the carry downward), so j == rev(i).
	for(int i = 1, j = 0; i < n; i++) {
		for(int k = n>>1; k > (j^=k); k >>= 1);
		if(i < j) swap(a[i], a[j]);
	}

	// m2[i] = m1[i * (Top/n)], i.e. the i-th power of the n-th root m1[Top/n].
	for(int i = 0; i < n; i++) m2[i] = m1[i*(Top/n)];
	// Butterfly stages for block sizes m = 2, 4, ..., n.
	for(int m = 2; m <= n; m <<= 1) {
		// Index step into m2 for this stage: n/m selects the m-th root w_m and
		// n - n/m (== -n/m mod n) selects w_m^{-1}. The forward pass uses the
		// inverse root and the inverse pass uses the root; this is consistent
		// because the two passes are opposite and the inverse divides by n.
		int wm = on == 1 ? n-n/m : n/m;
		for(int i = 0; i < n; i += m) {
			int w = 0; // current index into m2, kept in [0, n)
			for(int j = i; j < i+m/2; j++) {
				// (a[j], a[j+m/2]) <- (a[j] + t, a[j] - t), t = twiddle * a[j+m/2],
				// each result folded back into [0, mod).
				int t = 1ll*m2[w]*a[j+m/2]%mod;
				a[j+m/2] = a[j]-t;
				if(a[j+m/2] < 0) a[j+m/2] += mod;
				a[j] = a[j]+t;
				if(a[j] >= mod) a[j] -= mod;
				w += wm;
				if(w >= n) w -= n;
			}
		}
	}
	// Inverse transform: scale by n^{-1}, obtained as n^(mod-2) since mod is prime.
	if(on == -1) {
		LL Inv = powmod(n, mod-2);
		for(int i = 0; i < n; i++) a[i] = 1ll*a[i]*Inv%mod;
	}
}

// CDQ divide and conquer over dp[L..R] for the recurrence
//     dp[i] = (i-1)! * sum_{j<i} dp[j] / j! * (i-j)^2      (dp[0] = 1 set by init).
// After the left half [L, mid] is final, its contribution to every i in
// [mid+1, R] is added through one NTT convolution, then the right half is
// solved recursively. Requires f (factorials) and g (inverse factorials) to be
// filled, and dp[L] to be final on entry.
void cdq(int L, int R)
{
	if(L >= R) return;

	int mid = (L + R) >> 1;
	cdq(L, mid);

	// Only product indices mid+1-L .. R-L are read. The left operand occupies
	// [0, mid-L] and the right operand [1, R-L], so in a cyclic convolution of
	// length n >= R-L+1 any wrapped-around term lands at an index
	// <= (mid-L) + (R-L) - n < mid+1-L and never touches the entries we read.
	// This "middle product" length is much shorter than a full linear
	// convolution would require.
	int n = 1;
	while(n < R - L + 1) n <<= 1;

	for(int i = 0; i < n; i++) A[i] = B[i] = 0;

	// A[k] = dp[L+k] / (L+k)!   for the already-final left half.
	for(int i = L; i <= mid; i++) A[i - L] = dp[i] * g[i] % mod;
	// B[d] = d^2 for every distance d = i - j that can occur (1 .. R-L).
	for(int i = 1; i <= R - L; i++) B[i] = (LL)i * i % mod;

	NTT(A, n, 1);
	NTT(B, n, 1);
	for(int i = 0; i < n; i++) A[i] = A[i] * B[i] % mod;
	NTT(A, n, -1);

	// Add the left half's contribution, scaled by (i-1)!.
	for(int i = mid + 1; i <= R; i++) {
		dp[i] = (dp[i] + f[i-1] * A[i - L] % mod) % mod;
	}
	cdq(mid + 1, R);
}

void init(int n)
{
	Top = powmod(2, 18);
	f[0] = m1[0] = 1;
	LL t = powmod(3, (mod - 1) / Top);
	for(int i = 1; i < Top; i++) m1[i] = m1[i-1] * t % mod;
	for(int i = 1; i <= n; i++) f[i] = f[i-1] * i % mod;
	g[n] = powmod(f[n], mod - 2);
	for(int i = n-1; i >= 0; i--) g[i] = g[i+1] * (i + 1) % mod;
	dp[0] = 1;
	cdq(0, n);
}

// Precomputes dp[0..N] once, then answers each integer query n read from
// stdin with dp[n].
int main()
{
	const int N = 100000;  // largest n for which dp is precomputed
	init(N);
	int n;
	// scanf returns the number of items converted. Stopping on anything but 1
	// (EOF *or* a non-numeric token) avoids spinning forever on malformed input.
	while(scanf("%d", &n) == 1) {
		if(n < 0 || n > N) {
			// Outside the precomputed table: report instead of reading dp out of bounds.
			fprintf(stderr, "n out of range: %d\n", n);
			continue;
		}
		printf("%lld\n", dp[n]);
	}

	return 0;
}


你可能感兴趣的:(ntt)