HDU 4757 可持久化trie树

  首先如果给定一些数,询问这些数中哪个数^给定的数的值最大的话,我们可以建立一颗trie树,根连接的两条边分别为0,1,表示二进制下第15位,那么我们可以建立一颗trie树,每一条从根到叶子节点的链表示一个2^16以内的数,开始每个节点的cnt都是0,那么每插入一个元素,将表示这个值的链上所有位置的cnt++,那么对于一个值要求使得^最大,如果这个值的某一位是1,那么我们最好要找到一个为0的数来和他^,那么判断下0儿子的cnt是不是大于0,然后做就好了。

  那么对于这棵树,我们可以先将1为根,然后对于两个点x,y之间的链拆成x,lca和y,lca的两条链,现在问题就转化为了求一个深度递增的链上所有值和给定值的^值最大,那么我们可以建立可持久化trie,每个节点继承父节点的trie树,我们只需要用x的trie树减去lca father的trie树做开始的问题就好了。

  反思:调试的时候输出调试的,然后答案更新的只按照一部分更新的,忘了改回去了。因为这个题没看题,是别人讲的题意,所以没看到多组数据,在这儿一直错= =。

//By BLADEVIL

#include <cstdio>

#include <cstring>

#include <algorithm>

#define maxn 200010



using namespace std;



struct ww {

    int son[2];

    int cnt;

    ww() {

        cnt=0;

        memset(son,0,sizeof son);

    }

}t[maxn<<5];



int n,m,l,tot;

int a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],que[maxn],dis[maxn],jump[maxn][20];



void connect(int x,int y) {

    pre[++l]=last[x];

    last[x]=l;

    other[l]=y;

}



int lca(int x,int y) {

    if (dis[x]>dis[y]) swap(x,y);

    int dep=dis[y]-dis[x];

    for (int i=0;i<=18;i++) if (dep&(1<<i)) y=jump[y][i];

    if (x==y) return x;

    for (int i=18;i>=0;i--) if (jump[x][i]!=jump[y][i]) x=jump[x][i],y=jump[y][i];

    return jump[x][0];

}



void build(int &x,int dep) {

    if (!x) x=++tot;

    if (dep<0) return ;

    build(t[x].son[0],dep-1); build(t[x].son[1],dep-1);

}    



void insert(int &x,int rot,int dep,int key) {

    if (!x) x=++tot;

    if (dep==-2) return ; 

    if (key&(1<<dep)) {

        insert(t[x].son[1],t[rot].son[1],dep-1,key);

        t[x].son[0]=t[rot].son[0];

    } else {

        insert(t[x].son[0],t[rot].son[0],dep-1,key);

        t[x].son[1]=t[rot].son[1];

    }

    t[x].cnt+=t[rot].cnt+1;

    //printf("|%d %d\n",t[x].cnt,x);

}



int query(int rx,int lx,int key,int dep) {

    if (dep==-2) return 0;

    //printf("%d %d %d %d\n",t[rx].son[1],t[rx].son[0],t[t[rx].son[1]].cnt,t[t[rx].son[0]].cnt);

    int ans=0;

    if (key&(1<<dep)) {

        if (t[t[rx].son[0]].cnt-t[t[lx].son[0]].cnt) {

            ans=1<<dep;

            ans+=query(t[rx].son[0],t[lx].son[0],key,dep-1);

        } else ans=query(t[rx].son[1],t[lx].son[1],key,dep-1);

    } else {

        if (t[t[rx].son[1]].cnt-t[t[lx].son[1]].cnt) {

            ans=1<<dep;

            ans+=query(t[rx].son[1],t[lx].son[1],key,dep-1);

        } else ans=query(t[rx].son[0],t[lx].son[0],key,dep-1);

    }

    //printf("%d\n",ans);

    return ans;

}



void work() {

     memset(t,0,sizeof t);

    memset(last,0,sizeof last);

    memset(dis,0,sizeof dis);

    tot=n; l=0;

    for (int i=1;i<=n;i++) scanf("%d",&a[i]);

    for (int i=1;i<n;i++) {

        int x,y; scanf("%d%d",&x,&y);

        connect(x,y); connect(y,x);

    }

    int head=0,tail=1; que[1]=1; dis[1]=1;

    while (head<tail) {

        int cur=que[++head];

        for (int p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]) continue;

            que[++tail]=other[p]; dis[other[p]]=dis[cur]+1;

        }

    }

    //for (int i=1;i<=n;i++) printf(i==n?"%d\n":"%d ",que[i]);

    jump[1][0]=++tot;

    for (int i=1;i<=n;i++) 

        for (int p=last[que[i]];p;p=pre[p]) 

            if (dis[other[p]]>dis[que[i]]) jump[other[p]][0]=que[i];

    for (int i=1;i<=18;i++)

        for (int j=1;j<=n;j++){

            int cur=que[j];

            jump[cur][i]=jump[jump[cur][i-1]][i-1];

        }

    build(jump[1][0],15);

    for (int i=1;i<=n;i++) insert(que[i],jump[que[i]][0],15,a[que[i]]);

    //for (int i=1;i<=tot;i++) printf("%d %d %d %d\n",i,t[i].son[0],t[i].son[1],t[i].cnt);

    //int x,y; scanf("%d%d",&x,&y); printf("%d\n",lca(x,y));

    while (m--) {

        int x,y,z; scanf("%d%d%d",&x,&y,&z);

        int root=lca(x,y);

        int ans=0;

        ans=max(query(x,jump[root][0],z,15),query(y,jump[root][0],z,15));

        //ans=query(y,jump[root][0],z,15);

        printf("%d\n",ans);

    }

}



int main() {

    while (scanf("%d%d",&n,&m)!=EOF) work(); 

    return 0;

}

 

你可能感兴趣的:(trie)