【模拟赛|ZROI】01串(容斥,分治FFT)

题面

【模拟赛|ZROI】01串(容斥,分治FFT)_第1张图片
【模拟赛|ZROI】01串(容斥,分治FFT)_第2张图片
【模拟赛|ZROI】01串(容斥,分治FFT)_第3张图片

题解

前面的转化不重要,我就直接贴了(其实是因为我怎么努力都想不明白)
【模拟赛|ZROI】01串(容斥,分治FFT)_第4张图片
然后我们将每两个数中间加分割线(两端还有两个,总共 n + 1 n+1 n+1 个),每次选择了一个 01 01 01 后就顺便把分割线也删了。分割线删除的时间就是一个排列,每个 0 0 0 右边的分割线一定比左边的分割线早删, 1 1 1 相反, ? ? ? 随意。

所以我们就可以把 01 01 01 转化成排列中相邻两个数的相对大小限制< >。然后就是个经典题了。

对于一个排列,相邻两个数有大于或小于的限制,怎么做?

我们的做法是容斥。先保留所有的<符号,去掉>符号的限制,计算总方案数。这时一个>符号的限制不被满足,等价于原先的位置放上了<符号。我们根据这点容斥,令 d p [ i ] dp[i] dp[i] 表示考虑前 i i i 个位置的方案数。

我们枚举排列 1~i 中最后一个逆序位置 j ( p j > p j + 1 ) j(p_j>p_{j+1}) j(pj>pj+1) ,令 p r o [ i ] = ( − 1 ) i 之 前 > 符 号 的 个 数 pro[i]=(-1)^{i之前>符号的个数} pro[i]=(1)i c [ i ] c[i] c[i] 表示 i i i i + 1 i+1 i+1 之间的符号:
d p [ i ] = ∑ j < i , c [ j ] = ‘ > ’ d p [ j ] ⋅ ( p r o [ j + 1 ] ⋅ p r o [ i ] ) ⋅ ( i j ) = i ! ⋅ p r o [ i ] ∑ j < i , c [ j ] = ‘ > ’ d p [ j ] ⋅ p r o [ j + 1 ] j ! ⋅ 1 ( i − j ) ! dp[i]=\sum_{j’} dp[j]\cdot (pro[j+1]\cdot pro[i])\cdot {i\choose j}\\ =i!\cdot pro[i]\sum_{j’} \frac{dp[j]\cdot pro[j+1]}{j!}\cdot \frac{1}{(i-j)!} dp[i]=j<i,c[j]=>dp[j](pro[j+1]pro[i])(ji)=i!pro[i]j<i,c[j]=>j!dp[j]pro[j+1](ij)!1

我们用分治FFT(NTT)就好了,时间复杂度 O ( n log ⁡ 2 n ) O(n\log^2n) O(nlog2n)

CODE

#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define MAXN 250005
#define LL long long
#define DB double
#define lowbit(x) ((-x) & (x))
#define ENDL putchar('\n')
#define FI first
#define SE second
int xchar() {
    static const int mxn = 1000000;
    static char b[mxn];
    static int pos = 0,len = 0;
    if(pos == len) pos = 0,len = fread(b,1,mxn,stdin);
    if(pos == len) return -1;
    return b[pos ++];
}
//#define getchar() xchar()
LL read() {
    LL f=1,x=0;int s = getchar();
    while(s<'0' || s>'9') {if(s<0)return -1;if(s=='-')f=-f;s=getchar();}
    while(s>='0'&&s<='9') {x = (x<<3)+(x<<1)+(s^48);s = getchar();}
    return f*x;
}
void putpos(LL x) {
    if(!x) return ;
    putpos(x/10); putchar((x%10)^48);
}
void putnum(LL x) {
    if(!x) {putchar('0');return ;}
    if(x<0) putchar('-'),x=-x;
    return putpos(x);
}
void AIput(LL x,int c) {putnum(x);putchar(c);}

const int MOD = 998244353;
int n,m,s,o,k;
int fac[MAXN],inv[MAXN],invf[MAXN];
char ss[MAXN];
int om,xm[MAXN<<2],rev[MAXN<<2];
int qkpow(int a,int b) {
	int res = 1; while(b > 0) {
		if(b & 1) res = res *1ll* a % MOD;
		a = a *1ll* a % MOD; b >>= 1;
	} return res;
}
void NTT(int *s,int n,int op) {
	for(int i = 1;i < n;i ++) {
		rev[i] = (rev[i>>1]>>1) | ((i&1) ? (n>>1):0);
		if(rev[i] < i) swap(s[rev[i]],s[i]);
	} om = qkpow(3,(MOD-1)/n); xm[0] = 1;
	if(op < 0) om = qkpow(om,MOD-2);
	for(int i = 1;i <= n;i ++) xm[i] = xm[i-1] *1ll* om % MOD;
	for(int k = 2,t = n>>1;k <= n;k <<= 1,t >>= 1) {
		for(int j = 0;j < n;j += k) {
			for(int i = j,l = 0;i < j+(k>>1);i ++,l += t) {
				int A = s[i],B = s[i+(k>>1)];
				s[i] = (A + xm[l] *1ll* B) % MOD;
				s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD;
			}
		}
	}
	if(op < 0) {
		int iv = qkpow(n,MOD-2);
		for(int i = 0;i < n;i ++) s[i] = s[i] *1ll* iv % MOD;
	}return ;
}
int A[MAXN<<2],B[MAXN<<2];
int pro[MAXN],dp[MAXN];
int ST;
void solve(int l,int r) {
	if(l == r) return ;
	int md = (l + r) >> 1;
	solve(l,md);
	int le = 1;
	while(le <= (md-l)+(r-l)) le <<= 1;
	for(int i = 0;i < le;i ++) A[i] = B[i] = 0;
	for(int i = l;i <= md;i ++) {
		if(ss[i+1] == '1') A[i-l] = dp[i]*1ll*pro[i+1]%MOD*invf[i-ST+1]%MOD;
	}
	for(int i = 1;i <= r-l;i ++) B[i] = invf[i];
	NTT(A,le,1); NTT(B,le,1);
	for(int i = 0;i < le;i ++) A[i] = A[i] *1ll* B[i] % MOD;
	NTT(A,le,-1);
	for(int i = 0;i < le;i ++) {
		if(i+l > md && i+l <= r) {
			(dp[i+l] += fac[i+l-ST+1]*1ll*pro[i+l]%MOD*A[i]%MOD) %= MOD;
		}
	}
	solve(md+1,r);
	return ;
}
int main() {
	freopen("a.in","r",stdin);
	freopen("a.out","w",stdout);
	n = read();
	fac[0] = fac[1] = inv[0] = inv[1] = invf[0] = invf[1] = 1;
	for(int i = 2;i <= n+3;i ++) {
		fac[i] = fac[i-1] *1ll* i % MOD;
		inv[i] = (MOD - inv[MOD%i]) *1ll* (MOD/i) % MOD;
		invf[i] = invf[i-1] *1ll* inv[i] % MOD;
	}
	scanf("%s",ss + 1);
	pro[0] = 1;
	for(int i = 1;i <= n;i ++) {
		pro[i] = pro[i-1];
		if(ss[i] == '1') pro[i] = MOD-pro[i];
	}
	int ans = fac[n+1];
	for(int i = 0;i <= n;i ++) {
		int r = i;
		while(r < n && ss[r+1] != '?') r ++;
		for(int j = i;j <= r;j ++) {
			dp[j] = pro[j]*1ll*pro[i]%MOD;
		} ST = i;
		solve(i,r);
		ans = ans *1ll* dp[r] % MOD;
		ans = ans *1ll* invf[r-i+1] % MOD;
		i = r;
	}
	AIput(ans,'\n');
    return 0;
}

你可能感兴趣的:(分治,数学,C++,算法,容斥原理,快速傅里叶变换,分治)