hdu 5274 数链剖分 /dfs+数状数组

 注意 第一:点权为0.。。。  第 2:杭电扩展啊。。。  
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <iostream>
#include<stdio.h>
#include<cmath>
#include<string.h>
#include<algorithm>
#include<string>
using namespace std;
const int mmax = 100010;
const int inf=0x3fffffff;
struct edge
{
    int st,en;
    int next;
}E[2*mmax];
int p[mmax],fa[mmax],son[mmax],top[mmax],ID[mmax];
int deep[mmax],id_[mmax];
bool vis[mmax];
int w[mmax];
int num;
void add(int st,int en)
{
    E[num].st=st;
    E[num].en=en;
    E[num].next=p[st];
    p[st]=num++;
}
void init()
{
    memset(p,-1,sizeof p);
    num=0;
}
struct tree
{
    int l,r;
    int sum;
    int mid()
    {
        return (l+r)>>1;
    }
}T[4*mmax];
void build(int id,int l,int r)
{
    T[id].l=l,T[id].r=r;
    if(l==r)
    {
        T[id].sum=w[ID[l]];
        return ;
    }
    int mid=T[id].mid();
    build(id<<1,l,mid);
    build(id<<1|1,mid+1,r);
    T[id].sum=T[id<<1].sum^T[id<<1|1].sum;
}
void updata(int id,int pos,int val)
{
    if(T[id].l==T[id].r)
    {
        T[id].sum=val;
        return ;
    }
    int mid=T[id].mid();
    if(mid>=pos)
        updata(id<<1,pos,val);
    else
        updata(id<<1|1,pos,val);
    T[id].sum=T[id<<1].sum^T[id<<1|1].sum;
}
int query(int id,int l,int r)
{
    if(l<=T[id].l&&T[id].r<=r)
        return T[id].sum;
    int mid=T[id].mid();
    int ans=0;
    if(mid>=l)
        ans^=query(id<<1,l,r);
    if(mid<r)
        ans^=query(id<<1|1,l,r);
    return ans;
}

int dfs(int u)
{
    vis[u]=1;
    int cnt=1,tmp=0,e=0;
    for(int i=p[u];i+1;i=E[i].next)
    {
        int v=E[i].en;
        if(!vis[v])
        {
            fa[v]=u;
            deep[v]=deep[u]+1;
            int tt=dfs(v);
            cnt+=tt;
            if(tmp<tt)
            {
                tmp=tt;
                e=v;
            }
        }
    }
    son[u]=e;
    return cnt;
}
int now_cnt;
void new_id(int u)
{
    ID[now_cnt]=u;
    id_[u]=now_cnt;
    now_cnt++;
    vis[u]=1;
    if(son[u])
    {
        top[son[u]]=top[u];
        new_id(son[u]);
    }
    for(int i=p[u];i+1;i=E[i].next)
    {
        int v=E[i].en;
        if(!vis[v])
            new_id(v);
    }
}
int solve(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if(deep[top[x]]<deep[top[y]])
            swap(x,y);
        ans^=query(1,id_[top[x]],id_[x]);
        x=fa[top[x]];
    }
    if(deep[x]>deep[y])
        swap(x,y);
    ans^=query(1,id_[x],id_[y]);
    return ans;
}


int main()
{
    int n,q;
    int t;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d %d",&n,&q);
        init();
        for(int i=0;i<n-1;i++)
        {
            int u,v;
            scanf("%d %d",&u,&v);
            add(u,v);
            add(v,u);
        }
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&w[i]);
            if(w[i]==0)
                w[i]=mmax;
        }
        fa[1]=1;
        deep[1]=0;
        memset(vis,0,sizeof vis);
        for(int i=1;i<=n;i++)
            top[i]=i;
        dfs(1);
        memset(vis,0,sizeof vis);
        now_cnt=1;
        new_id(1);
        build(1,1,n);
        while(q--)
        {
            int d,x,y;
            scanf("%d %d %d",&d,&x,&y);
            if(d==0)
            {
                if(y==0)
                    y=mmax;
                updata(1,id_[x],y);
            }
            else
            {
                if(x>y)
                    swap(x,y);
                int tmp=solve(x,y);
                if(tmp==0)
                    puts("-1");
                else
                {
                    if(tmp==mmax)
                        tmp=0;
                    printf("%d\n",tmp);
                }
            }
        }
    }

    return 0;
}


第2种写法  利用dfs序列


#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <iostream>
#include<stdio.h>
#include<cmath>
#include<string.h>
#include<algorithm>
#include<string>
using namespace std;
const int mmax = 100010;
const int inf=0x3fffffff;
struct edge
{
    int st,en;
    int next;
}E[2*mmax];
int p[mmax];
int w[mmax];
int num;
void add(int st,int en)
{
    E[num].st=st;
    E[num].en=en;
    E[num].next=p[st];
    p[st]=num++;
}
void init()
{
    memset(p,-1,sizeof p);
    num=0;
}
int Times;
int deep[mmax],First[mmax],Last[mmax];
int C[mmax];
int fa[mmax][20];
int low_bit(int x)
{
    return x&(-x);
}
int n;
void update(int x)
{
    for(int i=First[x];i<=n;i+=low_bit(i))
        C[i]^=w[x];
    for(int i=Last[x];i<=n;i+=low_bit(i))
        C[i]^=w[x];
}
int get_sum(int x)
{
    int fg=0;
    for(int i=First[x];i>0;i-=low_bit(i))
        fg^=C[i];
    return fg;
}
void dfs(int u,int Deep)
{
    Times++;
    First[u]=Times;
    deep[u]=Deep;
    for(int i=1;(1<<i)<=deep[u];i++)
        fa[u][i]=fa[ fa[u][i-1]][i-1];

    for(int i=p[u];i+1;i=E[i].next)
    {
        int v=E[i].en;
        if(deep[v]==-1)
        {
            fa[v][0]=u;
            dfs(v,Deep+1);
        }
    }
    Last[u]=Times+1;
}


int lca(int x,int y)
{
    if(deep[x]<deep[y])
        swap(x,y);
    for(int i=19;i>=0;i--)
    {
        if(fa[x][i]!=-1 && deep[fa[x][i]]>=deep[y])
            x=fa[x][i];
        if(deep[x]==deep[y])
            break;
    }
    if(x==y)
        return x;
    for(int i=19;i>=0;i--)
    {
        if(fa[x][i]!=fa[y][i])
            x=fa[x][i],y=fa[y][i];
    }
    return fa[x][0];
}
int main()
{
    int q;
    int t;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d %d",&n,&q);
        init();
        for(int i=0;i<n-1;i++)
        {
            int u,v;
            scanf("%d %d",&u,&v);
            add(u,v);
            add(v,u);
        }
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&w[i]);
            w[i]++;
        }
        Times=0;
        memset(deep,-1,sizeof deep);
        memset(fa,-1,sizeof fa);
        dfs(1,0);
        //cout<<lca(1,2)<<endl;
        //system("pause");

       // for(int i=1;i<=n;i++)
            //cout<<First[i]<<" "<<Last[i]<<endl;

        memset(C,0,sizeof C);
        for(int i=1;i<=n;i++)
            update(i);
//        for(int i=1;i<=n;i++)
//            cout<<get_sum(i)<<" ";
//        cout<<endl;
        while(q--)
        {
            int d,x,y;
            scanf("%d %d %d",&d,&x,&y);
            if(d==0)
            {
                update(x);
                w[x]=(++y);
                update(x);
            }
            else
            {
                //cout<<lca(x,y)<<endl;
                int tmp=get_sum(x)^get_sum(y)^w[lca(x,y)];
                tmp--;
                printf("%d\n",tmp);
            }
        }

    }

    return 0;
}


你可能感兴趣的:(hdu 5274 数链剖分 /dfs+数状数组)