非常好的dp+组合题
这个版本的做法参考了tourist的editorial
我们不考虑两个序列的random shuffle,而是考虑这样的两个操作
1. 确定a序列和b序列的匹配方法
2. 确定这些匹配方法的出现顺序
我们考虑a序列和b序列匹配好以后,在A序列里面每个ai向bi连一条有向边
我们发现A序列的每个位置只有三种情况
1. 有某个a对应没有b对应,这样这个点只有出边
2. 由某个a对应也有某个b对应,这样这个点有入边也有出边
3. 没有a对应有某个b对应,这样这个点只有入边
设1类点有e个,2类点有m个,因为A和B中1的个数相等,所以3类点也有e个
考虑这个图本身,我们发现这个图一定是若干个环和若干个链组成的
我们发现组成环的这些位置所对应的A序列中的位置的数一定都是1,所以环当中边出现的顺序是任意的
我们发现链中匹配边出现的顺序有且只有一种,因为一条链中对应的A序列中的值,只有链尾是0,其他都是1,对应B序列中只有链头是0,其他都是1,所以边
一定要按照从后向前的顺序出现
环中的点都是2类点,链头和链尾都是1类和3类点,链中间的点是2类点
我们考虑怎样把2类点分到e条链和若干个环中
令dp[i][j]表示已经考虑到将j个点放入i条链的方案数, dp[i][j]=∑jk=0dp[i−1][k](j−k+1)! d p [ i ] [ j ] = ∑ k = 0 j d p [ i − 1 ] [ k ] ( j − k + 1 ) ! (分母上的阶乘的意义在后面解释)
最后把dp[e][0~m]的答案加起来,再乘上 e!∗m!∗(e+m)! e ! ∗ m ! ∗ ( e + m ) !
e!指e条链的链头和链尾配对,有e!种配对方法
m!指m个2类点的连接顺序,比如说一条链的点确定了,但这条链的连法有阶乘种
我们还要考虑边的出现顺序,所有的出现顺序是(e+m)!但是每条链的出现顺序只有一种所以要除以若干个(u+1)!,这个在算dp的时候已经除过了
这样就有了一个 O(n3) O ( n 3 ) 的做法
考虑优化
我们发现dp的转移方程是一个卷积的形式,k+(j-k+1)=j+1,所以每层的转移可以NTT优化,复杂度降到 O(n2logn) O ( n 2 l o g n )
然后我们发现每次乘的多项式都是一样的,都是 f(x)=∑mi=0xi(i+1)! f ( x ) = ∑ i = 0 m x i ( i + 1 ) ! ,所以可以快速幂+NTT,复杂度 O(nlog2n) O ( n l o g 2 n )
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;
const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
const double pi=3.14159265;
const int G=3;
inline int getint()
{
char ch;int res;bool f;
while (!isdigit(ch=getchar()) && ch!='-') {}
if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
while (isdigit(ch=getchar())) res=res*10+ch-'0';
return f?res:-res;
}
int inv[200048];
LL finv[200048],fac[200048];
int n;
char s1[100048],s2[100048];
int e,m;
inline void init_inv()
{
int i;
fac[0]=fac[1]=inv[0]=inv[1]=finv[0]=finv[1]=1;
for (i=2;i<=100000;i++)
{
fac[i]=(fac[i-1]*i)%MOD;
inv[i]=MOD-((long long)(MOD/i)*inv[MOD%i])%MOD;
finv[i]=(finv[i-1]*inv[i])%MOD;
}
}
inline LL quick_pow(LL x,LL y)
{
x%=MOD;LL res=1;
while (y)
{
if (y&1) res=(res*x)%MOD,y--;
x=(x*x)%MOD;y>>=1;
}
return res;
}
int len;
LL wn_pos[100048],wn_neg[100048];
inline void init_wn()
{
for (register int clen=2;clen<=len;clen<<=1)
{
wn_pos[clen]=quick_pow(G,(MOD-1)/clen);
wn_neg[clen]=quick_pow(G,(MOD-1)-(MOD-1)/clen);
}
}
LL a[100048],b[100048];
inline void NTT(LL c[],int fl)
{
int i,j,k,clen;
for (i=(len>>1),j=1;jif (ifor (k=(len>>1);i&k;k>>=1) i^=k;
i^=k;
}
for (clen=2;clen<=len;clen<<=1)
{
LL wn=(fl==1?wn_pos[clen]:wn_neg[clen]);
for (j=0;j1;
for (k=j;k>1);k++)
{
LL tmp1=c[k],tmp2=(c[k+(clen>>1)]*w)%MOD;
c[k]=(tmp1+tmp2)%MOD;c[k+(clen>>1)]=((tmp1-tmp2)%MOD+MOD)%MOD;
w=(w*wn)%MOD;
}
}
}
if (fl==-1)
for (i=0;iinline void calc_NTT()
{
NTT(a,1);NTT(b,1);
for (register int i=0;i1);
}
struct poly
{
LL A[100048];
inline poly operator * (const poly B) const
{
int i;poly res;
memset(a,0,sizeof(a));memset(b,0,sizeof(b));
for (i=0;i<=m;i++) a[i]=A[i],b[i]=B.A[i];
calc_NTT();
for (i=0;i<=m;i++) res.A[i]=a[i];
return res;
}
};
inline poly Quick_pow(poly x,LL y)
{
int i;poly res;
for (i=0;i<=m;i++) x.A[i]%=MOD;
res.A[0]=1;
while (y)
{
if (y&1) res=res*x,y--;
x=x*x;y>>=1;
}
return res;
}
int main ()
{
int i,j,k;
scanf("%s%s",s1+1,s2+1);n=strlen(s1+1);
init_inv();
e=m=0;
for (i=1;i<=n;i++)
{
if (s1[i]=='1' && s2[i]=='0') e++;
if (s1[i]=='1' && s2[i]=='1') m++;
}
len=1;while (len<=m*2) len<<=1;
init_wn();
poly ans;for (i=0;i<=m;i++) ans.A[i]=finv[i+1];
ans=Quick_pow(ans,e);
LL fans=0;
for (i=0;i<=m;i++) fans=(fans+ans.A[i])%MOD;
fans=fans*fac[e]%MOD*fac[m]%MOD*fac[e+m]%MOD;
printf("%lld\n",fans);
return 0;
}