HDU5909 FWT加速异或卷积

简略题意:
从一个节点数为 n n 的树中选出一颗子树,这个子树的值为它所有的节点的值异或起来的结果。现在需要你输出异或值分别为 [0,m) [ 0 , m ) 的方案树。

先考虑简单的树形 DP D P
dp[u][i] d p [ u ] [ i ] 代表以u为根的子树,所有节点的值异或起来为 i i 的方案树。
那么存在转移
dp[u][i]=vson(u)dp[v][j]dp[u][i xor j] d p [ u ] [ i ] = ∑ v ∈ s o n ( u ) d p [ v ] [ j ] ∗ d p [ u ] [ i   x o r   j ] .
那么暴力的复杂度就是 O(nmm) O ( n ∗ m ∗ m )
然后上面那个异或卷积可以用FWT加速一下,复杂度 O(nmlogm) O ( n ∗ m ∗ l o g m )

#define others
#ifdef poj
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#endif // poj
#ifdef others
#include 
#endif // others
//#define file
#define all(x) x.begin(), x.end()
using namespace std;
#define eps 1e-8
const double pi = acos(-1.0);

typedef long long LL;
typedef unsigned long long ULL;
void umax(int &a, int b) {
    a = max(a, b);
}
void umin(int &a, int b) {
    a = min(a, b);
}
int dcmp(double x) {
    return fabs(x) <= eps?0:(x > 0?1:-1);
}
void file() {
    freopen("data_in.txt", "r", stdin);
    freopen("data_out.txt", "w", stdout);
}

namespace solver {
    const LL mod = 1e9+7;
    const LL maxn = 1025;
    const LL rev = 500000004;
    LL t;
    LL n, m;
    LL val[maxn];
    vector G[maxn];
    LL dp[maxn][maxn];
    LL ans[maxn];
    LL Pow(LL a, LL b) {
        LL res = 1;
        while(b) {
            if(b & 1) res *= a, res%= mod;
            b >>= 1;
            a *= a;
            a %= mod;
        }
        return res;
    }
    void FWT(LL a[],LL n) {
        for(LL d=1;d1)
            for(LL m=d<<1,i=0;ifor(LL j=0;j//xor:a[i+j]=x+y,a[i+j+d]=(x-y+mod)%mod;
                    //and:a[i+j]=x+y;
                    //or:a[i+j+d]=x+y;
                }
    }

    void UFWT(LL a[],LL n) {
        for(LL d=1;d1)
            for(LL m=d<<1,i=0;ifor(LL j=0;j1LL*(x+y)*rev%mod,a[i+j+d]=(1LL*(x-y)*rev%mod+mod)%mod;
                    //xor:a[i+j]=(x+y)/2,a[i+j+d]=(x-y)/2;
                    //and:a[i+j]=x-y;
                    //or:a[i+j+d]=y-x;
                }
    }
    void solve(LL a[],LL b[],LL n) {
        FWT(a,n);
        FWT(b,n);
        for(LL i=0;i1LL*a[i]*b[i]%mod;
        UFWT(a,n);
    }

    void dfs(LL u, LL fa = -1) {
        dp[u][val[u]] = 1;
        for(auto v:G[u]) {
            if(v == fa) continue;
            dfs(v, u);
            LL tmp[1025] = {0};
            for(LL j = 0; j < m; j++) tmp[j] = dp[u][j];
//            for(LL j = 0; j < m; j++)
//                for(LL k = 0; k < m; k++)
//                    dp[u][j] += dp[v][k] * tmp[j^k];
            solve(tmp, dp[v], m);
            for(LL i = 0; i < m; i++) dp[u][i] += tmp[i];
        }
        for(LL i = 0; i < m; i++) ans[i] += dp[u][i], ans[i] %= mod;
    }
    void solve() {
//        cout<
        scanf("%lld", &t);
        while(t--) {
            for(LL i = 0; i < maxn; i++) G[i].clear();
            memset(dp, 0, sizeof dp);
            memset(ans, 0, sizeof ans);
            scanf("%lld%lld", &n, &m);
            for(LL i = 1; i <= n; i++) scanf("%lld", &val[i]);
            for(LL i = 1; i <= n - 1; i++) {
                LL x, y;
                scanf("%lld%lld", &x, &y);
                G[x].push_back(y);
                G[y].push_back(x);
            }
            dfs(1);
            for(LL i = 0; i < m; i++)
                printf("%lld%c", ans[i], i == m - 1?'\n':' ');
        }
    }
}

int main() {
//    file();
    solver::solve();
    return 0;
}

你可能感兴趣的:(FWT,DP)