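// Approach: find a diameter of the tree, then build binary-lifting (doubling)
// ancestor tables rooted at each of its two endpoints and answer every query
// with a k-th-ancestor lookup. This suffices because the vertex farthest from
// any vertex is always an endpoint of a diameter.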
#include <cstdio>
#include <queue>
#include <vector>

// Tree distance queries: for each query (x, d) print some vertex whose
// distance from x is exactly d, or 0 if no such vertex exists.
//
// Why two roots are enough: the vertex farthest from any x is an endpoint of a
// diameter (l or r). So a vertex at distance d exists iff
// d <= max(dist(x, l), dist(x, r)), and then the d-th ancestor of x in the tree
// rooted at l (or, failing that, rooted at r) is such a vertex.
//
// Input, repeated until EOF: "n Q", then n-1 edges "u v" over vertices 1..n,
// then Q queries "x d". Assumes n >= 1 and that the edges form a tree.

namespace {

typedef std::vector<std::vector<int> > Graph;  // adjacency lists; index 0 unused

// Iterative breadth-first search from `root` (no recursion, so long paths
// cannot overflow the call stack). Fills parent[v] (with parent[root] == root)
// and depth[v] (edges from root, -1 if unreachable), and returns the first
// vertex in BFS visiting order that has maximal depth.
int bfs(const Graph& g, int root, std::vector<int>& parent,
        std::vector<int>& depth) {
    const int n = static_cast<int>(g.size()) - 1;
    parent.assign(n + 1, -1);
    depth.assign(n + 1, -1);
    parent[root] = root;
    depth[root] = 0;

    std::queue<int> pending;
    pending.push(root);
    int farthest = root;  // initialized: with n == 1 no other vertex exists
    while (!pending.empty()) {
        const int u = pending.front();
        pending.pop();
        // Strict '>' keeps the first vertex reached at the maximal depth.
        if (depth[u] > depth[farthest]) farthest = u;
        for (std::size_t i = 0; i < g[u].size(); ++i) {
            const int v = g[u][i];
            if (depth[v] == -1) {
                depth[v] = depth[u] + 1;
                parent[v] = u;
                pending.push(v);
            }
        }
    }
    return farthest;
}

// Binary-lifting ancestor table for a tree with a fixed root.
class AncestorTable {
public:
    // `parent`/`depth` as produced by bfs(). Because parent[root] == root,
    // jumps that would pass the root simply stay at the root.
    AncestorTable(const std::vector<int>& parent, const std::vector<int>& depth)
        : depth_(depth) {
        const int n = static_cast<int>(parent.size()) - 1;
        int levels = 1;
        while ((1 << levels) <= n) ++levels;  // 2^levels > n > any depth
        up_.assign(levels, parent);
        for (int k = 1; k < levels; ++k) {
            for (int v = 1; v <= n; ++v) {
                const int mid = up_[k - 1][v];
                up_[k][v] = (mid < 0) ? -1 : up_[k - 1][mid];
            }
        }
    }

    // Returns the vertex exactly k edges above v, or 0 when v has fewer than
    // k ancestors or k is negative. k == 0 returns v itself.
    int kth_ancestor(int v, int k) const {
        if (k < 0 || k > depth_[v]) return 0;
        for (int bit = 0; k != 0; ++bit, k >>= 1) {
            if (k & 1) v = up_[bit][v];
        }
        return v;
    }

private:
    std::vector<std::vector<int> > up_;  // up_[k][v]: 2^k-th ancestor of v
    std::vector<int> depth_;             // depth_[v]: edges from the root
};

// Some vertex at distance exactly d from x, or 0 if none exists. Tries the
// tree rooted at diameter endpoint l first, then the one rooted at r.
int vertex_at_distance(const AncestorTable& from_l, const AncestorTable& from_r,
                       int x, int d) {
    const int via_l = from_l.kth_ancestor(x, d);
    return via_l != 0 ? via_l : from_r.kth_ancestor(x, d);
}

}  // namespace

int main() {
    int n;
    int queries;
    while (std::scanf("%d%d", &n, &queries) == 2) {
        Graph g(n + 1);  // sized per test case: no fixed cap, no global memset
        for (int i = 1; i < n; ++i) {
            int u;
            int v;
            if (std::scanf("%d%d", &u, &v) != 2) return 0;  // truncated input
            g[u].push_back(v);
            g[v].push_back(u);
        }

        // Double BFS: l is farthest from vertex 1, r is farthest from l; the
        // second BFS also leaves parent/depth describing the tree rooted at l.
        std::vector<int> parent;
        std::vector<int> depth;
        const int l = bfs(g, 1, parent, depth);
        const int r = bfs(g, l, parent, depth);
        const AncestorTable from_l(parent, depth);
        bfs(g, r, parent, depth);
        const AncestorTable from_r(parent, depth);

        while (queries-- > 0) {  // '> 0' so a negative count cannot loop forever
            int x;
            int d;
            if (std::scanf("%d%d", &x, &d) != 2) return 0;
            std::printf("%d\n", vertex_at_distance(from_l, from_r, x, d));
        }
    }
    return 0;
}