题目描述:
luogu
题解:
FST+FWT。
FWT_OR+FWT_XOR+FWT_AND=快乐。
首先对于这个式子来说,第一部分是个子集卷积,后面都可以用FWT做。
学了一下FST,它的原理基本是这样的:
子集卷积即$f[i]=\sum\limits_{j \subseteq i}g[j]*h[i-j]$,看起来很像FWT_OR但是有另外条件,即$j\&(i-j)=0$。
vfk爷的论文中提到过一种方法,多加一维去除非法情况。
设$cnt[i]$表示$i$的二进制表示中$1$的个数,那么原来的$f[i]$变为$f[cnt[i]][i]$,
原式变为$f[cnt[i]][i] = \sum\limits_{j \subseteq i} g[cnt[j]][j] * h[cnt[i-j]][i-j]$。
(对于$f[i][j]$,只有当$i=cnt[j]$时才可能有值)
这时我们对于所有$f[i]$做FWT_OR,这样的话每一位上的值是$F[i][j] = \sum\limits_{k \subseteq j} f[i][k]$。
然后直接$F[i][j] = \sum\limits_{k <= i} G[k][j] * H[i-k][j]$。
最后逆变换回去就是$f[i][j]$了。
注意$i<2^{17}$可能会有$17$个$1$。
代码:
#include#include #include using namespace std; typedef long long ll; const int N = 1000050; const int M = (1<<17)+50; const int MOD = 1000000007; const int inv_2 = (MOD+1)/2; template inline void read(T&x) { T f = 1,c = 0;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();} x = f*c; } template inline void Mod(T&x){if(x>=MOD)x-=MOD;} int n,len=1<<17,s[M],a[M],b[M],c[M],d[20][M],e[20][M],f[M],cnt[M]; void fwt_xor(int*a,int len,int k) { for(int i=1;i 1) for(int j=0;j 1)) for(int o=0;o) { int w1 = a[j+o],w2 = a[j+o+i]; Mod(a[j+o] = w1+w2),Mod(a[j+o+i] = w1+MOD-w2); if(k==-1)a[j+o]=1ll*a[j+o]*inv_2%MOD,a[j+o+i]=1ll*a[j+o+i]*inv_2%MOD; } } void fwt_or(int*a,int len,int k) { for(int i=1;i 1) for(int j=0;j 1)) for(int o=0;o) { if(k==1)Mod(a[j+o+i]+=a[j+o]); else Mod(a[j+o+i]+=MOD-a[j+o]); } } void fwt_and(int*a,int len,int k) { for(int i=1;i 1) for(int j=0;j 1)) for(int o=0;o) { if(k==1)Mod(a[j+o]+=a[j+o+i]); else Mod(a[j+o]+=MOD-a[j+o+i]); } } int main() { freopen("tt.in","r",stdin); read(n); for(int x,i=1;i<=n;i++) read(x),s[x]++; for(int i=1;i ) cnt[i]=cnt[i-(i&-i)]+1; f[1] = 1; for(int i=2;i ) Mod(f[i] = f[i-1]+f[i-2]); for(int i=0;i ) d[cnt[i]][i] = s[i]; for(int i=0;i<=17;i++) fwt_or(d[i],len,1); for(int i=0;i<=17;i++) for(int j=0;j<=i;j++) for(int k=0;k ) Mod(e[i][k]+=1ll*d[j][k]*d[i-j][k]%MOD); for(int i=0;i<=17;i++) fwt_or(e[i],len,-1); for(int i=0;i ) a[i] = 1ll*e[cnt[i]][i]*f[i]%MOD; for(int i=0;i ) b[i] = 1ll*s[i]*f[i]%MOD; fwt_xor(s,len,1); for(int i=0;i ) c[i] = 1ll*s[i]*s[i]%MOD; fwt_xor(c,len,-1); for(int i=0;i ) c[i] = 1ll*c[i]*f[i]%MOD; fwt_and(a,len,1),fwt_and(b,len,1),fwt_and(c,len,1); for(int i=0;i ) a[i] = 1ll*a[i]*b[i]%MOD*c[i]%MOD; fwt_and(a,len,-1); int ans = 0; for(int i=1;i 1) Mod(ans+=a[i]); printf("%d\n",ans); return 0; }