【题目】
原题地址
给定两个长度为 2 n 2^n 2n的序列 a , b a,b a,b,求他们的子集卷积,每一位答案%4。 n ≤ 21 n\leq 21 n≤21
【解题思路】
十分玄学,出题人这个想法十分厉害。(集训队论文里没有提过呢)
先说 FWT \text{FWT} FWT的做法:设 f ( i ) f(i) f(i)为 i i i的 p o p c o u n t popcount popcount
令 A i = a i ⋅ 4 f ( i ) , B i = b i ⋅ 4 f ( i ) A_i=a_i\cdot 4^{f(i)},B_i=b_i\cdot 4^{f(i)} Ai=ai⋅4f(i),Bi=bi⋅4f(i)
计算 C i = ∑ j ∣ k = i A j ⋅ B k C_i=\sum_{j|k=i}A_j\cdot B_k Ci=∑j∣k=iAj⋅Bk,则 c i = C i 4 f ( i ) % 4 c_i=\frac {C_i} {4^{f(i)}}\% 4 ci=4f(i)Ci%4
复杂度是 O ( n 2 n ) O(n2^n) O(n2n)的。
我们来分析一下这个做法,可以这样来理解:
这里相当于每个值是一个小的多项式,多项式的指数是集合的大小,每次乘完得到一个多项式,而我们只要这个多项式的最低位。所以我们将最低位移到 0 0 0次项然后取模,其他贡献就没了。
更具体的,我们要做的就是将 f ( i ) + f ( j ) = = f ( i ∣ j ) f(i)+f(j)==f(i|j) f(i)+f(j)==f(i∣j)与 f ( i ) + f ( j ) > f ( i ∣ j ) f(i)+f(j)>f(i|j) f(i)+f(j)>f(i∣j)区分开,那么我们乘上 4 f ( i ) 4^{f(i)} 4f(i)后,对于后者,我们至少会多乘上一个4,于是这样就消除了影响。
另外,我们可以考虑用vfk论文里面的形式幂级数来做。首先直接做子集卷积,暴力显然是 O ( n 2 2 n ) O(n^22^n) O(n22n)的。
这里我们令 A S , k A_{S,k} AS,k表示对于 a a a,集合为 S S S,集合大小为 k k k的值, B S , k B_{S,k} BS,k同理。
那么只有 A S , ∣ S ∣ = a S A_{S,|S|}=a_S AS,∣S∣=aS是有效的,其余位置都是无效的,我们限制集合大小来保证转化为集合并卷积后不会多算。
f X , k = ∑ S ∪ T = X ∑ i + j = k A S , i B T , j f_{X,k}=\sum_{S\cup T=X} \sum_{i+j=k} A_{S,i}B_{T,j} fX,k=S∪T=X∑i+j=k∑AS,iBT,j
答案就是每个 f S , ∣ S ∣ f_{S,|S|} fS,∣S∣
我们把第二维看成一个形式幂级数,就有:
f X ( x ) = ∑ S ∪ T = X A S ( x ) B T ( x ) f_X(x)=\sum_{S\cup T=X} A_S(x) B_T(x) fX(x)=S∪T=X∑AS(x)BT(x)
答案就是 [ x ∣ S ∣ ] f S ( x ) [x^{|S|}]f_S(x) [x∣S∣]fS(x)
形式幂级数的加法是 O ( n ) O(n) O(n)的,乘法是 O ( n 2 ) O(n^2) O(n2)的, FMT \text{FMT} FMT是 O ( n 2 n ) O(n2^n) O(n2n)的,总复杂度仍然没有改变。
但由于这里模数特殊,我们可以将系数压到一个 unsigned long long \text{unsigned long long} unsigned long long里处理,两位恰好可以存一个系数。
这样我们就可以 O ( 1 ) O(1) O(1)做加法和乘法,总复杂度就是 O ( n 2 n ) O(n2^n) O(n2n)的了。
这里似乎会有进位的影响?然而实际上并不需要考虑,因为 [ x i ] A S ( x ) [x^i]A_S(x) [xi]AS(x)和 [ x i ] B T ( x ) [x^i]B_T(x) [xi]BT(x)会贡献到 [ x i + j ] f S ∪ T [x^{i+j}]f_{S\cup T} [xi+j]fS∪T,总是有 i + j ≥ ∣ S ∪ T ∣ i+j\geq |S\cup T| i+j≥∣S∪T∣,我们需要的是 [ x ∣ S ∪ T ∣ ] f S ∪ T [x^{|S\cup T|}]f_{S\cup T} [x∣S∪T∣]fS∪T,进位不会对最低位产生影响。
这样子来看的话,思想上又和前面 FWT \text{FWT} FWT的做法统一了起来,十分有趣。
也可以看看出题人对此的理解
【参考代码】
我的 FWT \text{FWT} FWT
#include
using namespace std;
typedef long long ll;
typedef long double ldb;
const int N=(1<<21)+5;
int Log,n,bit[N];
ll mod,a[N],b[N],c[N];
char s[N],t[N];
void fwtor(ll *a,int n,int op)
{
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=i<<1)
for(int k=0;k<i;++k)
if(op==1) a[i+j+k]=(a[j+k]+a[i+j+k])%mod;
else a[i+j+k]=(a[i+j+k]-a[j+k]+mod)%mod;
}
ll qmul(ll a,ll b)
{
ll tmp=(ldb)a*b/mod;
return ((a*b-tmp*mod)%mod+mod)%mod;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("CF1034E.in","r",stdin);
freopen("CF1034E.out","w",stdout);
#endif
scanf("%d",&Log);n=1<<Log;mod=(ll)1<<(2*Log+2);
scanf("%s%s",s,t);
for(int i=1;i<n;++i) bit[i]=bit[i>>1]+(i&1);
for(int i=0;i<n;++i)
{
a[i]=(ll)(s[i]-'0')<<(bit[i]*2);
b[i]=(ll)(t[i]-'0')<<(bit[i]*2);
}
fwtor(a,n,1);fwtor(b,n,1);
for(int i=0;i<n;++i) c[i]=qmul(a[i],b[i]);
fwtor(c,n,-1);
for(int i=0;i<n;++i) c[i]>>=(bit[i]*2),putchar((c[i]&3)^48);
return 0;
}
别人的 FMT \text{FMT} FMT
#include
#include
#include
#include
#include
using namespace std;
#define ull unsigned long long
inline char read() {
static const int IN_LEN = 1000000;
static char buf[IN_LEN], *s, *t;
return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x) {
static bool iosig;
static char c;
for (iosig=false, c=read(); !isdigit(c); c=read()) {
if (c == '-') iosig=true;
if (c == -1) return;
}
for (x=0; isdigit(c); c=read()) x=((x+(x<<2))<<1)+(c^'0');
if (iosig) x=-x;
}
const int OUT_LEN = 1<<21;
char obuf[OUT_LEN], *ooh=obuf;
inline void print(char c) {
if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
*ooh++=c;
}
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); }
const int N = 1<<21, M = 22;
int n, cnt[N];
ull a[N], b[N];
int main() {
read(n);
for(int i=1; i<1<<n; ++i) cnt[i]=cnt[i^(i&-i)]+2;
char x;
while(isspace(x=read()));
for(int i=0; i<1<<n; ++i) a[i]=(ull)(x-'0')<<cnt[i], x=read();
while(isspace(x=read()));
for(int i=0; i<1<<n; ++i) b[i]=(ull)(x-'0')<<cnt[i], x=read();
for(int i=0; i<n; ++i) for(int j=0; j<1<<n; ++j) if(j>>i&1)
a[j]+=a[j^(1<<i)], b[j]+=b[j^(1<<i)];
for(int i=0; i<1<<n; ++i) a[i]*=b[i];
for(int i=0; i<n; ++i) for(int j=0; j<1<<n; ++j) if(j>>i&1) a[j]-=a[j^(1<<i)];
for(int i=0; i<1<<n; ++i) print((char)('0'+(a[i]>>cnt[i]&3)));
return flush(), 0;
}