#pragma comment(linker, "/STACK:1024000000,1024000000") #include <iostream> #include<stdio.h> #include<cmath> #include<string.h> #include<algorithm> #include<string> using namespace std; const int mmax = 100010; const int inf=0x3fffffff; struct edge { int st,en; int next; }E[2*mmax]; int p[mmax],fa[mmax],son[mmax],top[mmax],ID[mmax]; int deep[mmax],id_[mmax]; bool vis[mmax]; int w[mmax]; int num; void add(int st,int en) { E[num].st=st; E[num].en=en; E[num].next=p[st]; p[st]=num++; } void init() { memset(p,-1,sizeof p); num=0; } struct tree { int l,r; int sum; int mid() { return (l+r)>>1; } }T[4*mmax]; void build(int id,int l,int r) { T[id].l=l,T[id].r=r; if(l==r) { T[id].sum=w[ID[l]]; return ; } int mid=T[id].mid(); build(id<<1,l,mid); build(id<<1|1,mid+1,r); T[id].sum=T[id<<1].sum^T[id<<1|1].sum; } void updata(int id,int pos,int val) { if(T[id].l==T[id].r) { T[id].sum=val; return ; } int mid=T[id].mid(); if(mid>=pos) updata(id<<1,pos,val); else updata(id<<1|1,pos,val); T[id].sum=T[id<<1].sum^T[id<<1|1].sum; } int query(int id,int l,int r) { if(l<=T[id].l&&T[id].r<=r) return T[id].sum; int mid=T[id].mid(); int ans=0; if(mid>=l) ans^=query(id<<1,l,r); if(mid<r) ans^=query(id<<1|1,l,r); return ans; } int dfs(int u) { vis[u]=1; int cnt=1,tmp=0,e=0; for(int i=p[u];i+1;i=E[i].next) { int v=E[i].en; if(!vis[v]) { fa[v]=u; deep[v]=deep[u]+1; int tt=dfs(v); cnt+=tt; if(tmp<tt) { tmp=tt; e=v; } } } son[u]=e; return cnt; } int now_cnt; void new_id(int u) { ID[now_cnt]=u; id_[u]=now_cnt; now_cnt++; vis[u]=1; if(son[u]) { top[son[u]]=top[u]; new_id(son[u]); } for(int i=p[u];i+1;i=E[i].next) { int v=E[i].en; if(!vis[v]) new_id(v); } } int solve(int x,int y) { int ans=0; while(top[x]!=top[y]) { if(deep[top[x]]<deep[top[y]]) swap(x,y); ans^=query(1,id_[top[x]],id_[x]); x=fa[top[x]]; } if(deep[x]>deep[y]) swap(x,y); ans^=query(1,id_[x],id_[y]); return ans; } int main() { int n,q; int t; scanf("%d",&t); while(t--) { scanf("%d %d",&n,&q); init(); for(int i=0;i<n-1;i++) { int u,v; scanf("%d %d",&u,&v); add(u,v); add(v,u); } for(int i=1;i<=n;i++) { scanf("%d",&w[i]); if(w[i]==0) w[i]=mmax; } fa[1]=1; deep[1]=0; memset(vis,0,sizeof vis); for(int i=1;i<=n;i++) top[i]=i; dfs(1); memset(vis,0,sizeof vis); now_cnt=1; new_id(1); build(1,1,n); while(q--) { int d,x,y; scanf("%d %d %d",&d,&x,&y); if(d==0) { if(y==0) y=mmax; updata(1,id_[x],y); } else { if(x>y) swap(x,y); int tmp=solve(x,y); if(tmp==0) puts("-1"); else { if(tmp==mmax) tmp=0; printf("%d\n",tmp); } } } } return 0; }
第2种写法 利用dfs序列
#pragma comment(linker, "/STACK:1024000000,1024000000") #include <iostream> #include<stdio.h> #include<cmath> #include<string.h> #include<algorithm> #include<string> using namespace std; const int mmax = 100010; const int inf=0x3fffffff; struct edge { int st,en; int next; }E[2*mmax]; int p[mmax]; int w[mmax]; int num; void add(int st,int en) { E[num].st=st; E[num].en=en; E[num].next=p[st]; p[st]=num++; } void init() { memset(p,-1,sizeof p); num=0; } int Times; int deep[mmax],First[mmax],Last[mmax]; int C[mmax]; int fa[mmax][20]; int low_bit(int x) { return x&(-x); } int n; void update(int x) { for(int i=First[x];i<=n;i+=low_bit(i)) C[i]^=w[x]; for(int i=Last[x];i<=n;i+=low_bit(i)) C[i]^=w[x]; } int get_sum(int x) { int fg=0; for(int i=First[x];i>0;i-=low_bit(i)) fg^=C[i]; return fg; } void dfs(int u,int Deep) { Times++; First[u]=Times; deep[u]=Deep; for(int i=1;(1<<i)<=deep[u];i++) fa[u][i]=fa[ fa[u][i-1]][i-1]; for(int i=p[u];i+1;i=E[i].next) { int v=E[i].en; if(deep[v]==-1) { fa[v][0]=u; dfs(v,Deep+1); } } Last[u]=Times+1; } int lca(int x,int y) { if(deep[x]<deep[y]) swap(x,y); for(int i=19;i>=0;i--) { if(fa[x][i]!=-1 && deep[fa[x][i]]>=deep[y]) x=fa[x][i]; if(deep[x]==deep[y]) break; } if(x==y) return x; for(int i=19;i>=0;i--) { if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; } return fa[x][0]; } int main() { int q; int t; scanf("%d",&t); while(t--) { scanf("%d %d",&n,&q); init(); for(int i=0;i<n-1;i++) { int u,v; scanf("%d %d",&u,&v); add(u,v); add(v,u); } for(int i=1;i<=n;i++) { scanf("%d",&w[i]); w[i]++; } Times=0; memset(deep,-1,sizeof deep); memset(fa,-1,sizeof fa); dfs(1,0); //cout<<lca(1,2)<<endl; //system("pause"); // for(int i=1;i<=n;i++) //cout<<First[i]<<" "<<Last[i]<<endl; memset(C,0,sizeof C); for(int i=1;i<=n;i++) update(i); // for(int i=1;i<=n;i++) // cout<<get_sum(i)<<" "; // cout<<endl; while(q--) { int d,x,y; scanf("%d %d %d",&d,&x,&y); if(d==0) { update(x); w[x]=(++y); update(x); } else { //cout<<lca(x,y)<<endl; int tmp=get_sum(x)^get_sum(y)^w[lca(x,y)]; tmp--; printf("%d\n",tmp); } } } return 0; }