前面的转化不重要,我就直接贴了(其实是因为我怎么努力都想不明白)
然后我们将每两个数中间加分割线(两端还有两个,总共 n + 1 n+1 n+1 个),每次选择了一个 01 01 01 后就顺便把分割线也删了。分割线删除的时间就是一个排列,每个 0 0 0 右边的分割线一定比左边的分割线早删, 1 1 1 相反, ? ? ? 随意。
所以我们就可以把 01 01 01 转化成排列中相邻两个数的相对大小限制<
>
。然后就是个经典题了。
对于一个排列,相邻两个数有大于或小于的限制,怎么做?
我们的做法是容斥。先保留所有的<
符号,去掉>
符号的限制,计算总方案数。这时一个>
符号的限制不被满足,等价于原先的位置放上了<
符号。我们根据这点容斥,令 d p [ i ] dp[i] dp[i] 表示考虑前 i i i 个位置的方案数。
我们枚举排列 1~i 中最后一个逆序位置 j ( p j > p j + 1 ) j(p_j>p_{j+1}) j(pj>pj+1) ,令 p r o [ i ] = ( − 1 ) i 之 前 > 符 号 的 个 数 pro[i]=(-1)^{i之前>符号的个数} pro[i]=(−1)i之前>符号的个数 , c [ i ] c[i] c[i] 表示 i i i 和 i + 1 i+1 i+1 之间的符号:
d p [ i ] = ∑ j < i , c [ j ] = ‘ > ’ d p [ j ] ⋅ ( p r o [ j + 1 ] ⋅ p r o [ i ] ) ⋅ ( i j ) = i ! ⋅ p r o [ i ] ∑ j < i , c [ j ] = ‘ > ’ d p [ j ] ⋅ p r o [ j + 1 ] j ! ⋅ 1 ( i − j ) ! dp[i]=\sum_{j’} dp[j]\cdot (pro[j+1]\cdot pro[i])\cdot {i\choose j}\\ =i!\cdot pro[i]\sum_{j’} \frac{dp[j]\cdot pro[j+1]}{j!}\cdot \frac{1}{(i-j)!} dp[i]=j<i,c[j]=‘>’∑dp[j]⋅(pro[j+1]⋅pro[i])⋅(ji)=i!⋅pro[i]j<i,c[j]=‘>’∑j!dp[j]⋅pro[j+1]⋅(i−j)!1
我们用分治FFT(NTT)就好了,时间复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n) 。
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define MAXN 250005
#define LL long long
#define DB double
#define lowbit(x) ((-x) & (x))
#define ENDL putchar('\n')
#define FI first
#define SE second
int xchar() {
static const int mxn = 1000000;
static char b[mxn];
static int pos = 0,len = 0;
if(pos == len) pos = 0,len = fread(b,1,mxn,stdin);
if(pos == len) return -1;
return b[pos ++];
}
//#define getchar() xchar()
LL read() {
LL f=1,x=0;int s = getchar();
while(s<'0' || s>'9') {if(s<0)return -1;if(s=='-')f=-f;s=getchar();}
while(s>='0'&&s<='9') {x = (x<<3)+(x<<1)+(s^48);s = getchar();}
return f*x;
}
void putpos(LL x) {
if(!x) return ;
putpos(x/10); putchar((x%10)^48);
}
void putnum(LL x) {
if(!x) {putchar('0');return ;}
if(x<0) putchar('-'),x=-x;
return putpos(x);
}
void AIput(LL x,int c) {putnum(x);putchar(c);}
const int MOD = 998244353;
int n,m,s,o,k;
int fac[MAXN],inv[MAXN],invf[MAXN];
char ss[MAXN];
int om,xm[MAXN<<2],rev[MAXN<<2];
int qkpow(int a,int b) {
int res = 1; while(b > 0) {
if(b & 1) res = res *1ll* a % MOD;
a = a *1ll* a % MOD; b >>= 1;
} return res;
}
void NTT(int *s,int n,int op) {
for(int i = 1;i < n;i ++) {
rev[i] = (rev[i>>1]>>1) | ((i&1) ? (n>>1):0);
if(rev[i] < i) swap(s[rev[i]],s[i]);
} om = qkpow(3,(MOD-1)/n); xm[0] = 1;
if(op < 0) om = qkpow(om,MOD-2);
for(int i = 1;i <= n;i ++) xm[i] = xm[i-1] *1ll* om % MOD;
for(int k = 2,t = n>>1;k <= n;k <<= 1,t >>= 1) {
for(int j = 0;j < n;j += k) {
for(int i = j,l = 0;i < j+(k>>1);i ++,l += t) {
int A = s[i],B = s[i+(k>>1)];
s[i] = (A + xm[l] *1ll* B) % MOD;
s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD;
}
}
}
if(op < 0) {
int iv = qkpow(n,MOD-2);
for(int i = 0;i < n;i ++) s[i] = s[i] *1ll* iv % MOD;
}return ;
}
int A[MAXN<<2],B[MAXN<<2];
int pro[MAXN],dp[MAXN];
int ST;
void solve(int l,int r) {
if(l == r) return ;
int md = (l + r) >> 1;
solve(l,md);
int le = 1;
while(le <= (md-l)+(r-l)) le <<= 1;
for(int i = 0;i < le;i ++) A[i] = B[i] = 0;
for(int i = l;i <= md;i ++) {
if(ss[i+1] == '1') A[i-l] = dp[i]*1ll*pro[i+1]%MOD*invf[i-ST+1]%MOD;
}
for(int i = 1;i <= r-l;i ++) B[i] = invf[i];
NTT(A,le,1); NTT(B,le,1);
for(int i = 0;i < le;i ++) A[i] = A[i] *1ll* B[i] % MOD;
NTT(A,le,-1);
for(int i = 0;i < le;i ++) {
if(i+l > md && i+l <= r) {
(dp[i+l] += fac[i+l-ST+1]*1ll*pro[i+l]%MOD*A[i]%MOD) %= MOD;
}
}
solve(md+1,r);
return ;
}
int main() {
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
n = read();
fac[0] = fac[1] = inv[0] = inv[1] = invf[0] = invf[1] = 1;
for(int i = 2;i <= n+3;i ++) {
fac[i] = fac[i-1] *1ll* i % MOD;
inv[i] = (MOD - inv[MOD%i]) *1ll* (MOD/i) % MOD;
invf[i] = invf[i-1] *1ll* inv[i] % MOD;
}
scanf("%s",ss + 1);
pro[0] = 1;
for(int i = 1;i <= n;i ++) {
pro[i] = pro[i-1];
if(ss[i] == '1') pro[i] = MOD-pro[i];
}
int ans = fac[n+1];
for(int i = 0;i <= n;i ++) {
int r = i;
while(r < n && ss[r+1] != '?') r ++;
for(int j = i;j <= r;j ++) {
dp[j] = pro[j]*1ll*pro[i]%MOD;
} ST = i;
solve(i,r);
ans = ans *1ll* dp[r] % MOD;
ans = ans *1ll* invf[r-i+1] % MOD;
i = r;
}
AIput(ans,'\n');
return 0;
}