//count the tree
/*
思路:每个点建立一颗线段树(增量建立),以遍历的时间为序,充分利用上一颗线段树的信息,在这题上一颗线段树就是父节点的线段树,因为我们每次更新的信息只有一个节点,一个节点被更新了,那么它的所有祖先节点也要相应的被更新,又因为在线段树中一个节点的祖先节点数不会超过(logN)个,所以这颗线段树和上一颗线段树大部分节点是一样的,只有刚刚说的logN个节点被改变了,所以我们只要保存这改变的logN个节点就可以了,因此整个程序的空间复杂度就是O(NlogN)。
在这题里面,因为要查找的两个点a,b的这两颗线段树的信息,都是从lca(a,b)这个点里面的信息改变来的,所以其中的信息重复了一次;
所以在后面判断第k个数是在左子树上还是在右子树上时要减掉lca(a,b)这个点之前的信息,接下来的处理就和求线性区间第k大
一样的思路了。
*/
实现代码:
#include<stdio.h> #include<string.h> #include<algorithm> using namespace std; #define FN 2010000 #define N 101000 int tot,T[N],L[FN],R[FN],C[FN],pa[N][20],n; int build(int l,int r) { int mid,rt=++tot; if(l==r) return rt; mid=(l+r)>>1; L[rt]=build(l,mid); R[rt]=build(mid+1,r); return rt; } int update(int rt,int p,int d) { int newrt=++tot,l=1,r=n,mid,root=newrt; C[newrt]=C[rt]+d; while(r>l){ mid=(l+r)>>1; if(p<=mid){ L[newrt]=++tot;R[newrt]=R[rt]; rt=L[rt];newrt=L[newrt];r=mid; } else{ R[newrt]=++tot;L[newrt]=L[rt]; rt=R[rt];newrt=R[newrt];l=mid+1; } C[newrt]=C[rt]+d; } return root; } struct node{ int id,v,sd; }A[N]; int query(int t1,int t2,int lca,int k) { int l=1,r=n,mid,p=T[pa[lca][0]],t;lca=T[lca]; while(r>l){ mid=(l+r)>>1; t=C[L[t1]]+C[L[t2]]-C[L[lca]]-C[L[p]]; if(t>=k){ t1=L[t1];t2=L[t2]; lca=L[lca];p=L[p];r=mid; } else{ k-=t;t1=R[t1];t2=R[t2]; lca=R[lca];p=R[p];l=mid+1; } }return A[l].v; } struct E{ int t,next; }edge[2*N]; int ant,head[N],H[N]; void add(int a,int b) { edge[ant].t=b; edge[ant].next=head[a]; head[a]=ant++; } void dfs(int rt,int p,int dep) { int i; pa[rt][0]=p;H[rt]=dep; T[rt]=update(T[p],A[rt].sd,1); for(i=head[rt];i!=-1;i=edge[i].next) { if(edge[i].t==p) continue; dfs(edge[i].t,rt,dep+1); } } bool cmp1(node a,node b){ return a.v<b.v; } bool cmp2(node a,node b){ return a.id<b.id;} int B[N]; int Lca(int x,int y) { int k; if(x==y) return x; if(H[x]<H[y]) swap(x,y); for(k=B[H[x]-H[y]];k>=0;--k) if(H[x]-H[y]>=(1<<k)) x=pa[x][k]; if(x==y) return x; for(k=B[H[x]];k>=0;--k) { if(pa[x][k]&&pa[x][k]!=pa[y][k]) x=pa[x][k],y=pa[y][k]; } return pa[x][0]; } int main() { int m,i,a,b,k,lca; for(i=1;i<=N;i++) { B[i]=0;while(i>=(1<<B[i])) B[i]++; } while(scanf("%d%d",&n,&m)!=EOF) { tot=ant=0; memset(head,-1,sizeof(head)); for(i=1;i<=n;i++){ A[i].id=i; scanf("%d",&A[i].v); } T[0]=build(1,n); for(i=1;i<n;i++) { scanf("%d%d",&a,&b); add(a,b);add(b,a); } sort(A+1,A+n+1,cmp1); for(i=1;i<=n;i++) A[i].sd=i; sort(A+1,A+n+1,cmp2); dfs(1,0,0); sort(A+1,A+n+1,cmp1); for(k=1;k<20;k++) for(i=1;i<=n;i++) if(pa[i][k-1]) pa[i][k]=pa[pa[i][k-1]][k-1]; while(m--) { scanf("%d%d%d",&a,&b,&k);lca=Lca(a,b); printf("%d\n",query(T[a],T[b],lca,k)); } } return 0; }