FWT

线性变换>->

类比FFT

对于这类东西,我们考虑 t f ( A ) t f ( B ) = t f ( A ∗ B ) tf(A)tf(B)=tf(A*B) tf(A)tf(B)=tf(AB),其中*为某二元运算,tf为线性变换,设 C = A ∗ B C=A*B C=AB
形象的我们可以把tf认为 t f ( A ) i = ∑ j = 0 n A j f ( n , i , j ) tf(A)_{i}=\sum_{j=0}^{n}A_{j}f(n,i,j) tf(A)i=j=0nAjf(n,i,j),其中f(i,j)是一个函数
所以,我们有
t f ( A ) i t f ( B ) i = t f ( A ∗ B ) i tf(A)_{i}tf(B)_{i}=tf(A*B)_{i} tf(A)itf(B)i=tf(AB)i
( ∑ j = 0 n A j f ( n , i , j ) ) ( ∑ j = 0 n B j f ( n , i , j ) ) = ∑ j = 0 n C j f ( n , i , j ) (\sum_{j=0}^{n}A_{j}f(n,i,j))(\sum_{j=0}^{n}B_{j}f(n,i,j))=\sum_{j=0}^{n}C_{j}f(n,i,j) (j=0nAjf(n,i,j))(j=0nBjf(n,i,j))=j=0nCjf(n,i,j)
f ( n , i , j ) f ( n , i , k ) = f ( n , i , j ∗ k ) f(n,i,j)f(n,i,k)=f(n,i,j*k) f(n,i,j)f(n,i,k)=f(n,i,jk)
如果是FFT那么我们有 f ( n , i , j ) = ( w n i ) j f(n,i,j)=(w_{n}^{i})^{j} f(n,i,j)=(wni)j,易证明 f ( n , i , j ) f ( n , i , k ) = f ( n , i , j ∗ k ) f(n,i,j)f(n,i,k)=f(n,i,j*k) f(n,i,j)f(n,i,k)=f(n,i,jk)
现在我们需要位运算卷积

or卷积

我们定义 f ( n , i , j ) = [ ( j ∣ i ) = = i ] f(n,i,j)=[(j|i)==i] f(n,i,j)=[(ji)==i],显然满足 f ( n , i , j ) f ( n , i , k ) = f ( n , i , j ∗ k ) f(n,i,j)f(n,i,k)=f(n,i,j*k) f(n,i,j)f(n,i,k)=f(n,i,jk),所以我们有 t f ( A ) i = ∑ j = 0 n A j [ ( j ∣ i ) = = i ] tf(A)_{i}=\sum_{j=0}^{n}A_{j}[(j|i)==i] tf(A)i=j=0nAj[(ji)==i]
接下来我们考虑怎么计算tf(A),貌似可以直接子集和…
我们记 A 0 , A 1 A_{0},A_{1} A0,A1表示A的低 2 n − 1 2^{n-1} 2n1位与高 2 n − 1 2^{n-1} 2n1位,这时我们发现
t f ( A ) = ( t f ( A 0 ) , t f ( A 1 ) + t f ( A 0 ) ) tf(A)=(tf(A_{0}),tf(A_{1})+tf(A_{0})) tf(A)=(tf(A0),tf(A1)+tf(A0))
定义 I t f 为 t f Itf为tf Itftf的逆变换
I f ( n , i , j ) = [ j & i = = i ] ( − 1 ) ( c o u n t ( j ) − c o u n t ( i ) ) If(n,i,j)=[j\&i==i](-1)^{(count(j)-count(i))} If(n,i,j)=[j&i==i](1)(count(j)count(i))
I t f ( A ) = ( I t f ( A 0 ) , I t f ( A 1 ) − I t f ( A 0 ) ) Itf(A)=(Itf(A_{0}),Itf(A_{1})-Itf(A_{0})) Itf(A)=(Itf(A0),Itf(A1)Itf(A0))

and卷积

我们定义 f ( n , i , j ) = [ j & i = = i ] f(n,i,j)=[j\&i==i] f(n,i,j)=[j&i==i],显然满足 f ( n , i , j ) f ( n , i , k ) = f ( n , i , j ∗ k ) f(n,i,j)f(n,i,k)=f(n,i,j*k) f(n,i,j)f(n,i,k)=f(n,i,jk),所以 t f ( A ) i = ∑ j = 0 n A j [ j & i = = i ] tf(A)_{i}=\sum_{j=0}^{n}A_{j}[j\&i==i] tf(A)i=j=0nAj[j&i==i]
计算 t f ( A ) tf(A) tf(A),超集和
t f ( A ) = ( t f ( A 0 ) + t f ( A 1 ) , t f ( A 1 ) ) tf(A)=(tf(A_{0})+tf(A_{1}),tf(A_{1})) tf(A)=(tf(A0)+tf(A1),tf(A1))
I f ( n , i , j ) = [ j ∣ i = = i ] ( − 1 ) c o u n t ( i ) − c o u n t ( j ) If(n,i,j)=[j|i==i](-1)^{count(i)-count(j)} If(n,i,j)=[ji==i](1)count(i)count(j)
I t f ( A ) = ( I t f ( A 0 ) − I t f ( A 1 ) , I f t ( A 1 ) ) Itf(A)=(Itf(A_{0})-Itf(A_{1})_,Ift(A_{1})) Itf(A)=(Itf(A0)Itf(A1),Ift(A1))

xor卷积

我们定义 f ( n , i , j ) = ( − 1 ) c o u n t ( j & i ) f(n,i,j)=(-1)^{count(j\&i)} f(n,i,j)=(1)count(j&i),显然满足 f ( n , i , j ) f ( n , i , k ) = f ( n , i , j ∗ k ) f(n,i,j)f(n,i,k)=f(n,i,j*k) f(n,i,j)f(n,i,k)=f(n,i,jk)
所以 t f ( A ) i = ∑ j = 0 n A j ( − 1 ) c o u n t ( j & i ) tf(A)_{i}=\sum_{j=0}^{n}A_{j}(-1)^{count(j\&i)} tf(A)i=j=0nAj(1)count(j&i)
t f ( A ) = ( t f ( A 0 ) + t f ( A 1 ) , t f ( A 0 ) − t f ( A 1 ) ) tf(A)=(tf(A_{0})+tf(A_{1}),tf(A_{0})-tf(A_{1})) tf(A)=(tf(A0)+tf(A1),tf(A0)tf(A1))
I f ( n , i , j ) = 1 n ∑ j = 0 n A j ( − 1 ) c o u n t ( j & i ) If(n,i,j)=\frac{1}{n}\sum_{j=0}^{n}A_{j}(-1)^{count(j\&i)} If(n,i,j)=n1j=0nAj(1)count(j&i)
I t f ( A ) = ( t f ( A 0 ) + t f ( A 1 ) 2 , t f ( A 0 ) − t f ( A 1 ) 2 ) Itf(A)=(\frac{tf(A_{0})+tf(A_{1})}{2},\frac{tf(A_{0})-tf(A_{1})}{2}) Itf(A)=(2tf(A0)+tf(A1),2tf(A0)tf(A1))
注:以上用线性变换来解释fwt,貌似有点不太行
具体可以看2015年论文,集合幂级数貌似靠谱一点
【模板】快速沃尔什变换

#include
#define ll long long
#define pb(x) push_back(x)
using namespace std;
const int mod=998244353;
typedef vector<int> poly;
poly a_or,b_or,a_and,b_and,a_xor,b_xor;
int n,x,inv2;
int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int mul(int x,int y){return (ll)x*y%mod;}
int ksm(int x,int y){
	int ans=1;
	for (;y;y>>=1,x=mul(x,x)) if (y&1) ans=mul(ans,x);
	return ans;
}
inline poly operator*(poly a,poly b){
	for (int i=0;i<a.size();i++) a[i]=mul(a[i],b[i]);
	return a;
}
void Fwt_or(poly &a,int t){
	for (int i=1;i<(1<<n);i<<=1)
	 for (int j=0,p=(i<<1);j<(1<<n);j+=p)
	 	for (int k=j;k<j+i;k++) if (t==1) a[k+i]=(a[k+i]+a[k])%mod;
	 	else a[k+i]=dec(a[k+i],a[k]);
}
void Fwt_and(poly &a,int t){
	for (int i=1;i<(1<<n);i<<=1)
	 for (int j=0,p=(i<<1);j<(1<<n);j+=p)
	  for (int k=j;k<j+i;k++) if (t==1) a[k]=add(a[k+i],a[k]);
	  else a[k]=dec(a[k],a[k+i]);
}
void Fwt_xor(poly &a,int t){
	for (int i=1;i<(1<<n);i<<=1)
	 for (int j=0,p=(i<<1);j<(1<<n);j+=p)
	  for (int k=j;k<j+i;k++) {
	  	int x=a[k],y=a[k+i];
	  	if (t==1) a[k]=add(x,y),a[k+i]=dec(x,y);
	  	else a[k]=mul(add(x,y),inv2),a[k+i]=mul(dec(x,y),inv2);
	  }
}
int main(){
	scanf("%d",&n);
	for (int i=0;i<(1<<n);i++) scanf("%d",&x),a_or.pb(x),a_and.pb(x),a_xor.pb(x);
	for (int i=0;i<(1<<n);i++) scanf("%d",&x),b_or.pb(x),b_and.pb(x),b_xor.pb(x);
	inv2=ksm(2,mod-2);
	Fwt_or(a_or,1); Fwt_or(b_or,1); a_or=a_or*b_or; Fwt_or(a_or,-1);
	Fwt_and(a_and,1); Fwt_and(b_and,1); a_and=a_and*b_and; Fwt_and(a_and,-1);
	Fwt_xor(a_xor,1); Fwt_xor(b_xor,1); a_xor=a_xor*b_xor; Fwt_xor(a_xor,-1);
	for (int i=0;i<(1<<n);i++) printf("%d ",a_or[i]);
	printf("\n");
	for (int i=0;i<(1<<n);i++) printf("%d ",a_and[i]);
	printf("\n");
	for (int i=0;i<(1<<n);i++) printf("%d ",a_xor[i]);
}

【UNR #2】黎明前的巧克力
题意:
你现在有一个数集T,要从中选出一个子集s(s不为空),如果xor为0则对答案的贡献为2|s|否则不对答案产生贡献
解析:就是求 ∏ i = 1 n ( 1 + 2 x a i ) \prod_{i=1}^{n}(1+2x^{a_{i}}) i=1n(1+2xai)的第0项系数,这里的 ∏ \prod 为异或卷积
我们可以考虑把每个 ( 1 + 2 x a i ) (1+2x^{a_{i}}) (1+2xai)Fwt一下,再相乘,再Fwt回去,这样时间复杂度为 O ( n 2 l o g n ) O(n^2logn) O(n2logn)
我们考虑对 ( 1 + 2 x a i ) (1+2x^{a_{i}}) (1+2xai)Fwt本质上每位只会是3或-1,那么最后我们只需要求出每一位上有多少个3,有多少个-1即可.

#include
#define ll long long
using namespace std;
const int M=20;
const int mod=998244353;
int f[1<<M],fac[1<<M];
int n,x,inv2;
int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x-y<0?x-y+mod:x-y;}
int mul(int x,int y){return (ll)x*y%mod;}
int ksm(int x,int y){
	int ans=1;
	for (;y;y>>=1,x=mul(x,x)) if (y&1) ans=mul(ans,x);
	return ans;
}
void Fwt(int *a,int opt){
	for (int i=1;i<(1<<M);i<<=1)
	 for (int j=0,p=(i<<1);j<(1<<M);j+=p)
	  for (int k=j;k<j+i;k++) {
	  	int x=a[k],y=a[k+i];
	  	if (opt==1) a[k]=x+y,a[k+i]=x-y;
	  	else a[k]=mul(add(x,y),inv2),a[k+i]=mul(dec(x,y),inv2);
	  }
}
signed main(){
	scanf("%d",&n);
	for (int i=1;i<=n;i++) {
		scanf("%d",&x); f[x]++;
	}
	Fwt(f,1); 
	fac[0]=1; for (int i=1;i<(1<<M);i++) fac[i]=mul(fac[i-1],3);
	for (int i=0;i<(1<<M);i++) {
		int x=(n+f[i])/2,y=n-x;
		f[i]=fac[x]; if (y&1) f[i]=(mod-f[i])%mod;
	} 
	inv2=ksm(2,mod-2);
	Fwt(f,-1);
	printf("%d\n",dec(f[0],1));
}

你可能感兴趣的:(fwt)