FWT学习小结

引入

有的时候我们需要进行这样的求和:
∑ x ⊗ y a [ x ] ⋅ b [ y ] \sum_{x \otimes y} a[x]\cdot b[y] xya[x]b[y]
其中 ⊗ \otimes 为二元运算 a n d , o r , xor and ,or,\text{xor} and,or,xor之一,即位运算卷积.

暴力显然是 O ( n 2 ) O(n^2) O(n2),我们可不可以用类似 F F T FFT FFT的思想,把 a , b a,b a,b转化为 f w t [ a ] , f w t [ b ] fwt[a],fwt[b] fwt[a],fwt[b](转点值),然后令
f w t [ c ] = f w t [ a ] ⋅ f w t [ b ] fwt[c]=fwt[a]\cdot fwt[b] fwt[c]=fwt[a]fwt[b],然后再对 f w t [ c ] fwt[c] fwt[c]进行求逆呢?(由点值求系数) 答案是肯定的!

这样的正逆变换 称为 快速沃尔什变换.

o r or or

f w t [ a ] [ i ] = ∑ i ∣ j = i a [ j ] fwt[a][i]=\sum_{i|j=i} a[j] fwt[a][i]=ij=ia[j].
我们把每个二进制看做一维的话,就是一个高维前缀和啦~~

a 0 , a 1 a_0,a_1 a0,a1表示 a a a前后长度为 n / 2 n/2 n/2的系数子序列,令 a 0 + a 1 a_0+a_1 a0+a1表示对应位置相加, merge \text{merge} merge表示序列相接,则有.

f w t [ a ] = merge ( f w t [ a 0 ] , f w t [ a 0 + a 1 ] ) fwt[a]=\text{merge}(fwt[a_0],fwt[a_0+a_1]) fwt[a]=merge(fwt[a0],fwt[a0+a1]).

void fwt_or(int *f) {
    for(int k=1;k<n;k*=2)//维度
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                (f[i+j+k] += f[i+j]) %= mod;
}

现在需要证明的是
f w t [ c ] [ i ] = f w t [ a ] [ i ] ⋅ f w t [ b ] [ i ] fwt[c][i]=fwt[a][i]\cdot fwt[b][i] fwt[c][i]=fwt[a][i]fwt[b][i]
f w t [ a ] [ i ] ⋅ f w t [ b ] [ i ] = ( ∑ i ∣ j = i a [ j ] ) ⋅ ( ∑ k ∣ i = i b [ k ] ) fwt[a][i]\cdot fwt[b][i]=\left(\sum_{i|j=i} a[j] \right) \cdot \left( \sum_{k|i=i} b[k] \right) fwt[a][i]fwt[b][i]=ij=ia[j]ki=ib[k]
因 为 i ∣ j = i , k ∣ i = i → ( j ∣ k ) ∣ i = i 因为i|j=i,k|i=i\rightarrow (j|k)|i=i ij=i,ki=i(jk)i=i
f w t [ a ] [ i ] ⋅ f w t [ b ] [ i ] = ∑ ( j ∣ k ) ∣ i a [ j ] b [ k ] = ∑ j ∣ i = i c [ j ] = f w t [ c ] [ i ] fwt[a][i]\cdot fwt[b][i]=\sum_{(j|k)|i} a[j] b[k]=\sum_{j|i=i}c[j]=fwt[c][i] fwt[a][i]fwt[b][i]=(jk)ia[j]b[k]=ji=ic[j]=fwt[c][i]

逆变换的时候,只要把正变换的影响消去即可.
a = I F W T ( f w t [ a ] ) = merge ( I F W T ( f w t [ a 0 ] ) , I F W T ( f w t [ a 1 ] − f w t [ a 0 ] ) ) a=IFWT(fwt[a])=\text{merge}(IFWT(fwt[a_0]),IFWT(fwt[a_1]-fwt[a_0])) a=IFWT(fwt[a])=merge(IFWT(fwt[a0]),IFWT(fwt[a1]fwt[a0])).

void ifwt_or(int *f) {
    for(int k=n/2;k;k/=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                f[i+j+k] = (f[i+j+k]-f[i+j]+mod)%mod;
}

你可能觉得这样的话,上面的 k k k必须从 n / 2 n/2 n/2开始 f o r for for,其实从 k = 1 k=1 k=1开始结果是一样的.
因为你把高低位互换不影响变换的正确性,这个东西在后面都适用,所以两个代码可以合并.

void fwt_or(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                add(f[i+j+k],f[i+j]*x%mod); 
}

a n d and and

同理,设 f w t [ a ] [ i ] = ∑ j & i = i a [ j ] fwt[a][i]=\sum_{j\&i=i} a[j] fwt[a][i]=j&i=ia[j].
因为 j & i = i , k & i = i → ( j & k ) & i = i j\& i=i,k\& i=i\rightarrow (j\&k) \& i=i j&i=i,k&i=i(j&k)&i=i,所以同理可证转点值后相乘的结果正确.

正逆变化类似.

void fwt_and(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++) 
                add(f[i+j],f[i+j+k]*x%mod);
}

xor \text{xor} xor

c n t ( i ) cnt(i) cnt(i)表示 i i i 二进制下有多少个1.
定义 x ⊗ y = c n t ( x & y ) m o d    2 x\otimes y=cnt(x\&y)\mod 2 xy=cnt(x&y)mod2.
f w t [ a ] [ i ] = ∑ i ⊗ j = 0 a [ j ] − ∑ i ⊗ j = 1 a [ j ] = ∑ a [ j ] ∗ ( − 1 ) i ⊗ j fwt[a][i]=\sum_{i \otimes j=0} a[j]-\sum_{i\otimes j=1} a[j]=\sum a[j]*(-1)^{i\otimes j} fwt[a][i]=ij=0a[j]ij=1a[j]=a[j](1)ij.

性质: ( i ⊗ j )   xor   ( j ⊗ k ) = i ⊗ ( j   xor   k ) (i\otimes j) ~~\text{xor} ~~ (j\otimes k)=i\otimes(j~~\text{xor}~~k) (ij)  xor  (jk)=i(j  xor  k).
证明: xor \text{xor} xor为不进位加法,所以我们实际上是证明 c n t ( i & j ) + c n t ( i & k ) ≡ c n t ( i & ( j xor k ) ) ( m o d    2 ) cnt(i\&j)+cnt(i\&k)\equiv cnt(i\&(j\text{xor} k))(\mod 2) cnt(i&j)+cnt(i&k)cnt(i&(jxork))(mod2)

我们对一个位的所有情况进行证明,那么总体就一定满足.

i i i j j j k k k
0 - -
1 0 1
1 1 1

i = 0 i=0 i=0显然都是0.
i = 1 , j + k = 1 i=1,j+k=1 i=1,j+k=1,则显然成立.
i = 1 , j + k = 2 i=1,j+k=2 i=1,j+k=2,左边为2,右边为0,成立!

综上:我们通过穷举证明了每一位的情况,也就是证明了所有的情况都满足.

现在依然是要证明:
f w t [ c ] [ i ] = f w t [ a ] [ i ] ⋅ f w t [ b ] [ i ] fwt[c][i]=fwt[a][i]\cdot fwt[b][i] fwt[c][i]=fwt[a][i]fwt[b][i]
f w t [ a ] [ i ] ⋅ f w t [ b ] [ i ] = ( ∑ ( − 1 ) i ⊗ j a [ j ] ) ⋅ ( ∑ ( − 1 ) i ⊗ k b [ k ] ) = ∑ ( − 1 ) i ⊗ j   xor   i ⊗ k a [ j ] ⋅ b [ k ] fwt[a][i]\cdot fwt[b][i]=\left( \sum (-1)^{i\otimes j} a[j]\right) \cdot \left( \sum (-1)^{i\otimes k} b[k]\right) =\sum (-1)^{i\otimes j~~\text{xor}~ ~i \otimes k}a[j]\cdot b[k] fwt[a][i]fwt[b][i]=((1)ija[j])((1)ikb[k])=(1)ij  xor  ika[j]b[k]
∑ ( − 1 ) i ⊗ j   xor   i ⊗ k a [ j ] ⋅ b [ k ] = ∑ ( − 1 ) i ⊗ ( j  xor  k ) a [ j ] ⋅ b [ k ] = ∑ ( − 1 ) i ⊗ j c [ j ] = f w t [ c ] [ i ] \sum (-1)^{i\otimes j~~\text{xor}~ ~i \otimes k}a[j]\cdot b[k]=\sum (-1)^{i\otimes(j ~\text{xor} ~k)} a[j] \cdot b[k]=\sum (-1)^{i\otimes j} c[j]=fwt[c][i] (1)ij  xor  ika[j]b[k]=(1)i(j xor k)a[j]b[k]=(1)ijc[j]=fwt[c][i]





正变换:
f w t [ a ] = merge ( f w t [ a 0 ] + f w t [ a 1 ] , f w t [ a 0 ] − f w t [ a 1 ] ) fwt[a]=\text{merge} (fwt[a_0]+fwt[a_1],fwt[a_0]-fwt[a_1]) fwt[a]=merge(fwt[a0]+fwt[a1],fwt[a0]fwt[a1]).

证明:在求解小规模数据时 a 1 a_1 a1时不知道自己最高位位1的,
此时, i ∈ a 0 , j ∈ a 1 , c n t ( i & j ) = c n t ( ( i + n / 2 ) & j ) = f w t [ a 1 ] [ i ] i\in a_0,j\in a_1,cnt(i\&j)=cnt((i+n/2)\&j)=fwt[a_1][i] ia0,ja1,cnt(i&j)=cnt((i+n/2)&j)=fwt[a1][i].
i ∈ a 1 , j ∈ a 1 , c n t ( ( i + n / 2 ) & ( j + n / 2 ) ) = c n t ( i & j ) + 1 i\in a_1,j\in a_1,cnt((i+n/2)\& (j+n/2))=cnt(i\&j)+1 ia1,ja1,cnt((i+n/2)&(j+n/2))=cnt(i&j)+1,所以右边合并的时候 f w t [ a 1 ] fwt[a_1] fwt[a1]的符号改变.

逆变换:
I F W T ( a ) = merge ( I F W T ( a 0 + a 1 2 , I F W T ( a 0 − a 1 2 ) ) IFWT(a)=\text{merge} (IFWT(\dfrac{a_0+a_1}2,IFWT(\dfrac {a_0-a_1} 2)) IFWT(a)=merge(IFWT(2a0+a1,IFWT(2a0a1)).

void fwt_xor(int *f,ll x) {
    if(x==-1) x=(mod+1)/2;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)  {
                int u=f[i+j],v=f[i+j+k];
                add(f[i+j],v); del(f[i+j+k]=u,v);
                f[i+j] = f[i+j]*x%mod;
                f[i+j+k] = f[i+j+k]*x%mod;
            }
}

模板题

板子:

int n,a[N],b[N],A[N],B[N];
void add(int &x,int y) {x+=y; if(x>=mod)  x-= mod;}
void upd(int &x) {x+=x>>31&mod;}
void del(int &x,int y) {upd(x-=y);}

void fwt_or(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)
                add(f[i+j+k],f[i+j]*x%mod); 
}

void fwt_and(int *f,ll x) {
    if(x==-1) x+=mod;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++) 
                add(f[i+j],f[i+j+k]*x%mod);
}

void fwt_xor(int *f,ll x) {
    if(x==-1) x=(mod+1)/2;
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++)  {
                int u=f[i+j],v=f[i+j+k];
                add(f[i+j],v); del(f[i+j+k]=u,v);
                f[i+j] = f[i+j]*x%mod;
                f[i+j+k] = f[i+j+k]*x%mod;
            }
}

void solve(void (*fwt)(int*f,ll x)) {
    for(int i=0;i<n;i++) a[i]=A[i],b[i]=B[i];
    fwt(a,1); fwt(b,1);
    for(int i=0;i<n;i++) a[i]=(ll)a[i]*b[i]%mod;
    fwt(a,-1);
    for(int i=0;i<n;i++) pr1(a[i]);
    puts("");
}

int main() {
    qr(n); n=1<<n;
    for(int i=0;i<n;i++) qr(A[i]);
    for(int i=0;i<n;i++) qr(B[i]);
    solve(fwt_or);
    solve(fwt_and);
    solve(fwt_xor);
    return 0;
}

例题

bzoj #4589. Hard Nim

有两个神在玩nim游戏,有 n n n堆石子,每堆石子的大小为 ≤ m \le m m的质数,求先手必败的方案数.
m ≤ 50000 , n ≤ 1 0 9 m\le 50000,n\le 10^9 m50000,n109.

定义一个多项式 f , f [ p ] = [ p ≤ m ∩ p ∈ p r i m e ] f,f[p]=[p\le m\cap p\in prime] f,f[p]=[pmpprime]
定义 f × g f\times g f×g表示 f , g f,g f,g对应位置相乘,区别于 f ∗ g f*g fg(表示卷积).
则我们要求的就是 I F W T ( f w t [ f ] n ) [ 0 ] IFWT(fwt[f]^n)[0] IFWT(fwt[f]n)[0].
因为每乘一次就相当于作一次异或卷积,即多一堆石子,所以正确.

复杂度为 O ( m ( log ⁡ m + log ⁡ n ) ) O(m(\log m+\log n)) O(m(logm+logn)).


void add(int &x,int y) {x+=y; if(x>=mod) x-=mod;}
void upd(int &x) {x+=x>>31&mod;}
void del(int &x,int y) {upd(x -= y);}


int prime[N],tot; bool v[N];
void get(int x) {
    for(int i=2;i<=x;i++) {
        if(!v[i]) prime[++tot]=i;
        for(int j=1,k;(k=i*prime[j])<=x;j++) {
            v[k]=1;
            if(i%prime[j]==0) break;
        }
    }
}

int t;
void fwt(int *f,ll x) {
    if(x==-1) x=(mod+1)/2;
    for(int k=1;k<t;k*=2)
        for(int i=0;i<t;i+=2*k)
            for(int j=0;j<k;j++) {
                int u=f[i+j],v=f[i+j+k];
                add(f[i+j],v); del(f[i+j+k]=u,v);
                f[i+j]=f[i+j]*x%mod;
                f[i+j+k]=f[i+j+k]*x%mod;
            }
}

int n,m,f[N];
ll power(ll a,ll b=n) {
    ll c=1;
    while(b) {
        if(b&1) c=c*a%mod;
        b /= 2; a=a*a%mod;
    }
    return c;
}

int main() {

    get(N-1);
    while(~scanf("%d%d",&n,&m)) {
        for(t=1;t<=m;t*=2);
        memset(f,0,sizeof f);
        for(int i=1;i<=tot&&prime[i]<=m;i++) f[prime[i]]=1;;
        fwt(f,1);
        for(int i=0;i<t;i++) f[i]=power(f[i]);
        fwt(f,-1); pr2(f[0]);
    }
    return 0;
}

P3175 [HAOI2015]按位或

定义 m i n ( T ) min(T) min(T)为取到 T T T集合中任意一位的最小时间.
m i n ( T ) = 1 ∑ S ∩ T ≠ ∅ p S = 1 1 − f w t [ T ‾ ] min(T)=\dfrac 1{\sum_{S\cap T\ne \varnothing} p_S}=\dfrac 1 {1-fwt[\overline T]} min(T)=ST=pS1=1fwt[T]1

#include
using namespace std;
const int N=(1<<20)|10;
const double eps=1e-9;

int n,g[N];
double f[N],ans;

void fwt(double *f) {
    for(int k=1;k<n;k*=2)
        for(int i=0;i<n;i+=2*k)
            for(int j=0;j<k;j++) 
                f[i+j+k] += f[i+j];
}

int main() {
    scanf("%d",&n); n=1<<n;
    for(int i=0;i<n;i++) scanf("%lf",&f[i]);
    fwt(f); g[0]=-1;
    for(int i=1;i<n;i++) {
        g[i]=-g[i&(i-1)];
        if(fabs(f[i^(n-1)]-1)<eps) 
            {puts("INF"); return 0;}
        ans += g[i]/(1-f[i^(n-1)]);
    }
    printf("%.10lf\n",ans); 
    return 0;
}

你可能感兴趣的:(#,多项式)