SPOJ Problem Set (classical): 913. Query on a tree II (problem code: QTREE2)
You are given a tree (an undirected acyclic connected graph) with N nodes, and edges numbered 1, 2, 3...N-1. Each edge has an integer value assigned to it, representing its length.
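We will ask you to perform instructions of the following forms:
DIST a b : ask for the distance between node a and node b
KTH a b k : ask for the k-th node on the path from node a to node b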
Example:
N = 6
1 2 1 // the edge connecting node 1 and node 2 has cost 1
2 4 1
2 5 2
1 3 1
3 6 2
Path from node 4 to node 6 is 4 -> 2 -> 1 -> 3 -> 6
DIST 4 6 : answer is 5 (1 + 1 + 1 + 2 = 5)
KTH 4 6 4 : answer is 3 (the 4-th node on the path from node 4 to node 6 is 3)
The first line of input contains an integer t, the number of test cases (t <= 25). t test cases follow.
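For each test case:
- The first line contains an integer N, the number of nodes.
- Each of the next N-1 lines contains three integers a b c, describing an edge between nodes a and b with cost c.
- The following lines contain instructions "DIST a b" or "KTH a b k".
- The end of each test case is signified by the string "DONE".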
There is one blank line between successive tests.
For each "DIST" or "KTH" operation, write one integer representing its result.
Print one blank line after each test.
Input:
1

6
1 2 1
2 4 1
2 5 2
1 3 1
3 6 2
DIST 4 6
KTH 4 6 4
DONE

Output:
5
3
-------------------------------------------------------------------
题目大意:给定一颗有边权的树,有两种操作,DIST(i,j)操作询问i和j节点之间的距离,KTH(i,j,k)操作询问i和j节点之间第k个节点的编号。
解题思路:利用树上的倍增就可以搞定。每个节点都保存它的第2^i的父亲。对于DIST询问,只要利用倍增求出lca,然后减一减就好了。对于KTH询问,先求出lca,然后判断是第一个点到lca的路径上还是第二个点到lca的路径上。哎算是水题,不过第一次用倍增,RE好久。
// SPOJ QTREE2: answers DIST / KTH path queries on a weighted tree using
// binary lifting (fa[s][i] = 2^i-th ancestor of s).
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <vector>

const int N = 200005;  // capacity for node ids (nodes are numbered 1..n)

int n, eid;
// Adjacency list stored as linked edge arrays: head[s] is the first edge out
// of s (-1 if none), nxt[i] is the next edge after edge i, ed[i] is its
// endpoint and val[i] its length.
int head[N], ed[N << 1], val[N << 1], nxt[N << 1];
// fa[s][i] is the 2^i-th ancestor of s; the vector holds exactly the
// ancestors that exist (so the root's vector is empty).
std::vector<int> fa[N];
int par[N];    // parent of each node (0 for the root; ids are 1-based)
int dep[N];    // depth of each node, root has depth 1
int sta[N];    // explicit work stack used by build()
long long dis[N];  // distance from the root; long long so dis[a]+dis[b] cannot overflow

// Adds the directed edge s -> e with length v.
void addedge(int s, int e, int v) {
    ed[eid] = e;
    val[eid] = v;
    nxt[eid] = head[s];
    head[s] = eid++;
}

// Computes par, dep, dis and the fa tables for every node reachable from root.
// Iterative rather than recursive so that a path-shaped tree with many nodes
// cannot overflow the call stack. A node is popped only after its parent has
// been popped, so the parent's fa table is already complete when it is used.
void build(int root) {
    int top = 0;
    par[root] = 0;
    dep[root] = 1;
    dis[root] = 0;
    sta[top++] = root;
    while (top > 0) {
        const int s = sta[--top];
        fa[s].clear();
        if (par[s] != 0) {
            fa[s].push_back(par[s]);
            // The 2^(i+1)-th ancestor is the 2^i-th ancestor of the 2^i-th ancestor.
            for (size_t i = 0; i < fa[fa[s][i]].size(); ++i) {
                const int up = fa[fa[s][i]][i];
                fa[s].push_back(up);
            }
        }
        for (int i = head[s]; i != -1; i = nxt[i]) {
            const int e = ed[i];
            if (e == par[s]) continue;  // do not walk back to the parent
            par[e] = s;
            dep[e] = dep[s] + 1;
            dis[e] = dis[s] + val[i];
            sta[top++] = e;  // each node is pushed once, so top never exceeds n
        }
    }
}

// Returns the ancestor of a that lies at depth target.
// Requires 1 <= target <= dep[a]. Jumps by the set bits of the depth difference.
int climb(int a, int target) {
    int diff = dep[a] - target;
    for (int i = 0; diff > 0; ++i, diff >>= 1) {
        if (diff & 1) a = fa[a][i];
    }
    return a;
}

// Returns the lowest common ancestor of a and b.
int lca(int a, int b) {
    if (dep[a] < dep[b]) std::swap(a, b);
    a = climb(a, dep[b]);
    if (a == b) return a;
    // a and b are now at equal depth, so their fa vectors have equal size.
    // Lift both by the largest jumps that keep their ancestors distinct.
    for (int i = (int)fa[a].size() - 1; i >= 0; --i) {
        if (i < (int)fa[a].size() && fa[a][i] != fa[b][i]) {
            a = fa[a][i];
            b = fa[b][i];
        }
    }
    return fa[a][0];  // a and b are now distinct children of the LCA
}

// Returns the k-th node (1-based) on the path from a to b.
// Requires 1 <= k <= number of nodes on that path.
int kth(int a, int b, int k) {
    const int r = lca(a, b);
    const int up = dep[a] - dep[r] + 1;  // nodes from a up to and including r
    if (k <= up) return climb(a, dep[a] - k + 1);
    // Otherwise the node lies on the r -> b leg, (k - up) levels below r.
    return climb(b, dep[r] + (k - up));
}

int main() {
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        if (scanf("%d", &n) != 1) break;
        eid = 0;
        memset(head, -1, sizeof(head));
        for (int i = 1; i < n; ++i) {
            int a, b, c;
            if (scanf("%d%d%d", &a, &b, &c) != 3) return 0;
            addedge(a, b, c);
            addedge(b, a, c);
        }
        build(1);

        char cmd[20];
        // "DIST" and "KTH" are distinguished by their second letter; "DONE" ends
        // the test. Stop on EOF too instead of re-reading a stale buffer forever.
        while (scanf("%19s", cmd) == 1 && cmd[1] != 'O') {
            if (cmd[1] == 'I') {
                int a, b;
                if (scanf("%d%d", &a, &b) != 2) break;
                const int r = lca(a, b);
                printf("%lld\n", dis[a] + dis[b] - 2 * dis[r]);
            } else {
                int a, b, c;
                if (scanf("%d%d%d", &a, &b, &c) != 3) break;
                printf("%d\n", kth(a, b, c));
            }
        }
        puts("");  // one blank line after each test case
    }
    return 0;
}