link
给两个01串 A , B ,它们 1 的个数相同,记 ai 表示 A 中 1 的出现位置, bi 表示 B 中 1 的出现位置,将 a 和 b 打乱后依次交换 Aai 和 Abi ,求有多少种方式使得 A=B 。
我们先考虑 a 和 b 的匹配,再考虑顺序。
对于一个给定的匹配,我们从 ai 到 bi 连一条有向边,注意到每个点入度和出度不超过 1 ,所以这个图一定由若干个环和若干条链构成,环的顺序随意而链的顺序唯一。
假设有 x 个 Ai=Bi=1 , y 个 Ai=1,Bi=0 ,我们发现边数是 x+y ,图由 y 条链和若干个环构成。
我们考虑将 Ai=Bi=1 的点分配到链上面,记 f(i,j) 表示考虑前 i 条链,分配了 j 个点的方案数,那么有转移:
f(i,j)=∑f(i−1,j−k)(k+1)!
最后答案是:
∑f(y,i)×x!×y!×(x+y)! 。
其中 x! 表示分配点的顺序, y! 表示匹配个数, (x+y)! 表示边的顺序方案数。
注意到 f(y) 实际上是一个多项式的 y 次方,直接快速幂即可。
#include
#define xx first
#define yy second
#define mp make_pair
#define pb push_back
#define mset(x, y) memset(x, y, sizeof x)
#define mcpy(x, y) memcpy(x, y, sizeof x)
using namespace std;
typedef long long LL;
typedef pair <int, int> pii;
inline int Read()
{
int x = 0, f = 1, c = getchar();
for (; !isdigit(c); c = getchar())
if (c == '-')
f = -1;
for (; isdigit(c); c = getchar())
x = x * 10 + c - '0';
return x * f;
}
const int MAXN = 32770;
const int MAXM = 20;
const int mod = 998244353;
const int G = 3;
namespace FFT
{
int wn[MAXM], R[MAXN], L, n;
inline int Qow(int x, int y)
{
int ret = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1)
ret = 1LL * ret * x % mod;
return ret;
}
inline int Inv(int x)
{
return Qow(x, mod - 2);
}
inline void Ini(int len)
{
for (n = 1, L = 0; n < len; n <<= 1, L ++);
for (int i = 1; i < n; i ++)
R[i] = (R[i >> 1] >> 1) | ((i & 1) << L - 1);
wn[0] = Qow(G, mod - 1 >> L);
for (int i = 1; i < L; i ++)
wn[i] = 1LL * wn[i - 1] * wn[i - 1] % mod;
}
inline void Ini_Inv()
{
wn[0] = Inv(wn[0]);
for (int i = 1; i < L; i ++)
wn[i] = 1LL * wn[i - 1] * wn[i - 1] % mod;
}
inline void FFT(int *x)
{
for (int i = 0; i < n; i ++)
if (i < R[i])
swap(x[i], x[R[i]]);
for (int i = 1, l = L - 1; i < n; i <<= 1, l --)
for (int j = 0; j < n; j += i << 1)
for (int k = 0, w = 1, u, v; k < i; k ++, w = 1LL * w * wn[l] % mod)
u = x[j + k], v = 1LL * x[i + j + k] * w % mod, x[j + k] = (u + v) % mod, x[i + j + k] = (u + mod - v) % mod;
}
inline void Cpy(int *x, int n, int *y, int m)
{
for (int i = 0; i < n; i ++)
y[i] = x[i];
for (int i = n; i < m; i ++)
y[i] = 0;
}
inline void Mul(int *a, int n, int *b, int m, int *ret)
{
static int x[MAXN], y[MAXN];
Ini(n + m - 1);
Cpy(a, n, x, FFT::n); Cpy(b, m, y, FFT::n);
FFT(x); FFT(y);
for (int i = 0; i < FFT::n; i ++)
x[i] = 1LL * x[i] * y[i] % mod;
Ini_Inv();
FFT(x);
int v = Inv(FFT::n);
for (int i = 0; i < FFT::n; i ++)
ret[i] = 1LL * v * x[i] % mod;
}
}
int n, m, fac[MAXN], inv[MAXN], x[MAXN], r[MAXN];
char s[MAXN], t[MAXN];
int main()
{
#ifdef wxh010910
freopen("data.in", "r", stdin);
#endif
scanf("%s%s", s + 1, t + 1);
for (int i = 1, l = strlen(s + 1); i <= l; i ++)
if (s[i] == '1' && t[i] == '1')
n ++;
else if (s[i] == '1' && t[i] == '0')
m ++;
r[0] = fac[0] = inv[0] = fac[1] = inv[1] = 1;
for (int i = 2; i <= n + m + 1; i ++)
fac[i] = 1LL * fac[i - 1] * i % mod, inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
for (int i = 2; i <= n + m + 1; i ++)
inv[i] = 1LL * inv[i] * inv[i - 1] % mod;
for (int i = 0; i <= n; i ++)
x[i] = inv[i + 1];
for (int t = m; t; t >>= 1, FFT::Mul(x, n + 1, x, n + 1, x))
if (t & 1)
FFT::Mul(x, n + 1, r, n + 1, r);
int ret = 0;
for (int i = 0; i <= n; i ++)
ret = (ret + r[i]) % mod;
return printf("%d\n", 1LL * ret * fac[n] % mod * fac[m] % mod * fac[n + m] % mod), 0;
}