HDU4670 Cube number on a tree 树分治

    人生的第一道树分治,要是早点学我南京赛就不用那么挫了,树分治的思路其实很简单,就是对子树找到一个重心(Centroid),实现重心分解,然后递归的解决分开后的树的子问题,关键是合并,当要合并跨过重心的两棵子树的时候,需要有一个接近O(n)的方法,因为f(n)=kf(n/k)+O(n)解出来才是O(nlogn).在这个题目里其实就是将第一棵子树的集合里的每个元素,判下有没符合条件的,有就加上,然后将子树集合压进大集合,然后继续搞第二棵乃至第n棵.我的过程用了map,合并是nlogn的所以代码速度颇慢,大概6s,题目时限10s,可以改成hash应该会快许多,毕竟用map实在太慢,用vector也可以,具体可以参见挑战程序设计竞赛代码.下面的代码查找重心用了挑战的代码.

#pragma comment(linker, "/STACK:102400000,102400000")

#include<iostream>

#include<cstring>

#include<string>

#include<cstdio>

#include<algorithm>

#include<map>

#include<vector>

#define maxv 50000

#define ll long long

using namespace std;



int n,k;

vector<int> G[maxv+50];

ll val[maxv+50];

ll prime[maxv+50];

ll convert_three(ll v)

{

    ll bas=1;ll res=0;

    for(int i=0;i<k;++i){

        int num=0;

        while(v%prime[i]==0){

            v/=prime[i];

            num++;

        }

        num%=3;res+=num*bas;

        bas*=3;

    }

    return res;

}



ll xor(ll x,ll y)

{

    ll res=0;ll bas=1;

    for(int i=0;i<k;++i){

        res+=((x%3)+(y%3))%3*bas;

        x/=3;y/=3;

        bas*=3;

    }

    return res;

}



ll inv(ll x)

{

    ll res=0;ll bas=1;

    for(int i=0;i<k;++i){

        res+=((3-(x%3))%3)*bas;

        x/=3;

        bas*=3;

    }

    return res;

}



void print(ll x){

    while(x){

        cout<<x%3;

        x/=3;

    }

    cout<<endl;

}



bool centroid[maxv+50];

int ssize[maxv+50];

int ans;



map<ll,int> sta;

map<ll,int>::iterator it;

int compute_ssize(int v,int p)

{

    int c=1;

    for(int i=0;i<G[v].size();++i){

        int w=G[v][i];

        if(w==p||centroid[w]) continue;

        c+=compute_ssize(G[v][i],v);

    }

    ssize[v]=c;

    return c;

}



pair<int,int> search_centroid(int v,int p,int t)

{

    pair<int,int> res=make_pair(INT_MAX,-1);

    int s=1,m=0;

    for(int i=0;i<G[v].size();++i){

        int w=G[v][i];

        if(w==p||centroid[w]) continue;

        res=min(res,search_centroid(w,v,t));

        m=max(m,ssize[w]);

        s+=ssize[w];

    }

    m=max(m,t-s);

    res=min(res,make_pair(m,v));

    return res;

}



void enumerate_mul(int v,int p,ll d,map<ll,int> &ds)

{

    if(ds.count(d)) ds[d]++;

    else ds[d]=1;

    for(int i=0;i<G[v].size();++i){

        int w=G[v][i];

        if(w==p||centroid[w]) continue;

        enumerate_mul(w,v,xor(d,val[w]),ds);

    }

}



void solve(int v)

{

    compute_ssize(v,-1);

    int s=search_centroid(v,-1,ssize[v]).second;

    centroid[s]=true;

    for(int i=0;i<G[s].size();++i){

        if(centroid[G[s][i]]) continue;

        solve(G[s][i]);

    }

    sta.clear();

    sta[val[s]]=1;map<ll,int> tds;

    for(int i=0;i<G[s].size();++i){

        if(centroid[G[s][i]]) continue;

        tds.clear();

        enumerate_mul(G[s][i],s,val[G[s][i]],tds);

        it=tds.begin();

        while(it!=tds.end()){

            ll rev=inv((*it).first);

            if(sta.count(rev)){

                ans+=sta[rev]*(*it).second;

            }

            ++it;

        }

        it=tds.begin();

        while(it!=tds.end()){

            ll  vv=xor((*it).first,val[s]);

            if(sta.count(vv)){

                sta[vv]+=(*it).second;

            }

            else{

                sta[vv]=(*it).second;

            }

            ++it;

        }

    }

    centroid[s]=false;

}



int main()

{

    while(cin>>n>>k){

        ans=0;

        for(int i=0;i<k;++i){

            scanf("%I64d",&prime[i]);

        }

        G[0].clear();

        for(int i=1;i<=n;++i){

            scanf("%I64d",&val[i]);

            val[i]=convert_three(val[i]);

            if(val[i]==0) ans++;

            //print(val[i]);

            G[i].clear();

        }

        int u,v;

        for(int i=0;i<n-1;++i){

            scanf("%d%d",&u,&v);

            G[u].push_back(v);

            G[v].push_back(u);

        }

        memset(centroid,0,sizeof(centroid));

        solve(1);

        printf("%d\n",ans);

    }

    return 0;

}

 

你可能感兴趣的:(number)