这道题目本来应该1A的。。结果中间调试输出忘了删WA了一发,又PE了一发。。/(ㄒoㄒ)/~~提示里面怎么没有不能有文尾回车!!。。。
首先跑dfs,一般情况下主席树的第i颗线段树应该是根据第i-1颗建立的;但是在树中,第i颗线段树应该在第fa[i]颗的基础上建立。其余一模一样。
查询也略有区别,首先根据dfs可以logN求lca。在查询[x,y]的第k大的数时,令u=lca(x,y),v=fa[u],那么在树上x-y的路径相当于sum[x]+sum[y]-sum[u]-sum[v]。其余一模一样。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define N 100005 #define M 2000005 using namespace std; int n,m,cnt,tot,trtot,dfsclk,d[N],fa[N][17],a[N],num[N],mi2[25],hash[N]; int rt[N],ls[M],rs[M],sum[M],fst[N],pnt[N<<1],nxt[N<<1]; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void add(int aa,int bb){ pnt[++tot]=bb; nxt[tot]=fst[aa]; fst[aa]=tot; } int find(int x){ int l=1,r=cnt,mid; while (l<r){ mid=(l+r)>>1; if (hash[mid]<x) l=mid+1; else r=mid; } return l; } int lca(int x,int y){ if (d[x]<d[y]) swap(x,y); int i,tmp=d[x]-d[y]; for (i=0; i<=16; i++) if (tmp&mi2[i]) x=fa[x][i]; for (i=16; i>=0; i--) if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; } return (x==y)?x:fa[x][0]; } void ins(int l,int r,int x,int &y,int v){ y=++trtot; sum[y]=sum[x]+1; int mid=(l+r)>>1; if (l==r) return; if (v<=mid){ rs[y]=rs[x]; ins(l,mid,ls[x],ls[y],v); } else{ ls[y]=ls[x]; ins(mid+1,r,rs[x],rs[y],v); } } int qry(int x,int y,int k){ int u=lca(x,y),v=fa[u][0],l=1,r=cnt; x=rt[x]; y=rt[y]; u=rt[u]; v=rt[v]; while (l<r){ int mid=(l+r)>>1,tmp=sum[ls[x]]+sum[ls[y]]-sum[ls[u]]-sum[ls[v]]; if (tmp>=k){ x=ls[x]; y=ls[y]; u=ls[u]; v=ls[v]; r=mid; } else{ k-=tmp; x=rs[x]; y=rs[y]; u=rs[u]; v=rs[v]; l=mid+1; } } return hash[l]; } void dfs(int x,int last){ d[x]=d[last]+1; fa[x][0]=last; int p,i; ins(1,cnt,rt[fa[x][0]],rt[x],a[x]); for (i=1; mi2[i]<=d[x]; i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; if (y!=last) dfs(y,x); } } int main(){ n=read(); m=read(); int i; for (i=1; i<=n; i++) a[i]=num[i]=read(); sort(num+1,num+n+1); for (i=1; i<=n; i++) if (i==1 || num[i]!=num[i-1]) hash[++cnt]=num[i]; for (i=1; i<=n; i++) a[i]=find(a[i]); for (i=1; i<n; i++){ int x=read(),y=read(); add(x,y); add(y,x); } mi2[0]=1; for (i=0; mi2[i]<=n; i++) mi2[i+1]=mi2[i]<<1; dfs(1,0); int ans=0; while (m--){ int x=read()^ans,y=read(),k=read(); printf("%d",ans=qry(x,y,k)); if (m) puts(""); } return 0; }
by lych
2016.2.12