快速沃尔什变换,简写FWT,变换肯定会想到FFT,和其相似,FWT同样是用于求解某些特定的卷积的。
FFT求解的问题一般可以化成如下形式
C k = Σ i + j = k A i ∗ B j C_k=\Sigma_{i+j=k}A_i*B_j Ck=Σi+j=kAi∗Bj
但是会发现,如果求解的 C k C_k Ck中 k k k的条件变成 i ∣ j = k i|j=k i∣j=k , i & j i \And j i&j 或者 i ⨁ j i \bigoplus j i⨁j就没有办法计算,所以我们需要使用FWT来解决这类问题。
关于FWT的证明就不写了,我们可以通过其结论可知
F W T ( A ) = ( F W T ( A 0 ) , F W T ( A 0 + A 1 ) ) FWT(A)=(FWT(A_0),FWT(A_0+A_1)) FWT(A)=(FWT(A0),FWT(A0+A1))
中间的 A 0 , A 1 A_0,A_1 A0,A1分别表示 A A A的前 2 n − 1 2^{n-1} 2n−1和后 2 n − 1 2^{n-1} 2n−1部分,其实这个结论也可以通过感性认识一下,因为进行的是或操作,显然如果将 A 0 A_0 A0中的某一项和 A 1 A_1 A1中的某一项进行计算,得到的结果只能够影响到 F W T ( A 1 ) FWT(A_1) FWT(A1)的部分,对于 F W T ( A 0 ) FWT(A_0) FWT(A0)没有贡献,所以可以得到这个结论。
而根据
F W T ( A + B ) = F W T ( A ) ∗ F W T ( B ) FWT(A+B)=FWT(A)*FWT(B) FWT(A+B)=FWT(A)∗FWT(B)
我们就可以得到
F W T ( A ∣ B ) = F W T ( A ) ∗ F W T ( B ) FWT(A|B)=FWT(A)*FWT(B) FWT(A∣B)=FWT(A)∗FWT(B)的结论,也就证明了FWT的正确性。
F W T ( A ) = ( F W T ( A 0 + A 1 ) , F W T ( A 1 ) ) FWT(A)=(FWT(A_0+A_1),FWT(A_1)) FWT(A)=(FWT(A0+A1),FWT(A1))
F W T ( A ) = ( F W T ( A 0 + A 1 ) , F W T ( A 0 − A 1 ) ) FWT(A)=(FWT(A_0+A_1),FWT(A_0-A_1)) FWT(A)=(FWT(A0+A1),FWT(A0−A1))
这两个部分同样可以经过操作证明FWT的正确性。
void FWT_or(int *a,int opt)
{
for(int i=1;i<N;i<<=1)
for(int p=i<<1,j=0;j<N;j+=p)
for(int k=0;k<i;++k)
if(opt==1)a[i+j+k]=(a[j+k]+a[i+j+k])%MOD;
else a[i+j+k]=(a[i+j+k]+MOD-a[j+k])%MOD;
}
void FWT_and(int *a,int opt)
{
for(int i=1;i<N;i<<=1)
for(int p=i<<1,j=0;j<N;j+=p)
for(int k=0;k<i;++k)
if(opt==1)a[j+k]=(a[j+k]+a[i+j+k])%MOD;
else a[j+k]=(a[j+k]+MOD-a[i+j+k])%MOD;
}
void FWT_xor(int *a,int opt)
{
for(int i=1;i<N;i<<=1)
for(int p=i<<1,j=0;j<N;j+=p)
for(int k=0;k<i;++k)
{
int X=a[j+k],Y=a[i+j+k];
a[j+k]=(X+Y)%MOD;a[i+j+k]=(X+MOD-Y)%MOD;
if(opt==-1)a[j+k]=1ll*a[j+k]*inv2%MOD,a[i+j+k]=1ll*a[i+j+k]*inv2%MOD;
}
}
代码和思路来源于https://www.cnblogs.com/cjyyb/p/9065615.html
简单来讲就是给你一棵树,每个点有权值 V i V_i Vi,定义树上联通块的权值 S k = ⨁ i = 1.. s i z e ( k ) V i S_k = \bigoplus_{i=1..size(k)}V_i Sk=⨁i=1..size(k)Vi,求问有多少个联通块的权值为 i i i。
首先不难看出这个问题是一道树形dp,记 f [ i ] [ j ] f[i][j] f[i][j]表示以 i i i为根的子树中,联通块值为 j j j的数目,如果直接暴力合并子树的话,复杂度为 O ( N 3 ) O(N^3) O(N3),如果在合并异或的时候,我们使用FWT,复杂度就会降成 O ( N 2 l o g 2 N ) O(N^2log_{2}N) O(N2log2N)。
#include
#define rep( i , l , r ) for( int i = (l) ; i <= (r) ; ++i )
#define per( i , r , l ) for( int i = (r) ; i >= (l) ; --i )
#define erep( i , u ) for( int i = head[(u)] ; ~i ; i = e[i].nxt )
using namespace std;
const int maxn = 1111 , maxm = 1111 , MOD = 1e9 + 7 , inv2 = 5e8 + 4;
int head[maxn] , _t = 0;
struct edge{
int v , nxt;
}e[maxn << 1];
inline void addedge( int u , int v ){
e[_t].v = v , e[_t].nxt = head[u] , head[u] = _t++;
e[_t].v = u , e[_t].nxt = head[v] , head[v] = _t++;
}
int f[maxn][maxm] , ans[maxm];
void FWT( int *a , int N , int inv ){
for( int i = 1 ; i < N ; i <<= 1 )
for( int p = i << 1 , j = 0 ; j < N ; j += p )
for( int k = 0 ; k < i ; ++k ){
int X = a[j + k] , Y = a[i + j + k];
a[j + k] = (X + Y) % MOD;
a[i + j + k] = (X + MOD - Y) % MOD;
if( inv == -1 ){
a[j + k] = 1ll * a[j + k] * inv2 % MOD;
a[i + j + k] = 1ll * a[i + j + k] * inv2 % MOD;
}
}
}
int N , M;
void dfs( int u , int fa ){
FWT( f[u] , M , 1 );
erep( i , u ){
int v = e[i].v;
if( v == fa ) continue;
dfs( v , u );
for( int i = 0 ; i < M ; ++i ) f[u][i] = (f[u][i] * f[v][i]) % MOD;
}
FWT( f[u] , M , -1 );
if( ++f[u][0] == MOD ) f[u][0] -= MOD;
FWT( f[u] , M , 1 );
}
inline int _read(){
int x = 0 , f = 1;
char ch = getchar();
while( ch > '9' || ch < '0' ){
if( ch == '-' ) f = -1;
ch = getchar();
}
while( '0' <= ch && ch <= '9' ){
x = (x << 3) + (x << 1) + ch - '0';
ch = getchar();
}
return x * f;
}
int main(){
int T = _read();
while( T-- ){
N = _read() , M = _read();
memset( f , 0 , sizeof f );
memset( head , 0xff , sizeof head );
memset( ans , 0 , sizeof ans );
rep( i , 1 , N ) f[i][_read()] = 1;
_t = 0;
int u , v;
rep( i , 1 , N - 1 ){
u = _read() , v = _read();
addedge( u , v );
}
dfs( 1 , -1 );
for( int i = 1 ; i <= N ; ++i ) FWT( f[i] , M , -1 );
for( int i = 1 ; i <= N ; ++i )
if( --f[i][0] < 0 ) f[i][0] += MOD;
for( int i = 1 ; i <= N ; ++i ){
for( int j = 0 ; j < M ; ++j ){
ans[j] = (ans[j] + f[i][j]) % MOD;
}
}
for( int i = 0 ; i < M ; ++i ){
printf("%d",ans[i]) , putchar(i==M-1?'\n':' ');
}
}
return 0;
}