题目链接~~>
做题感悟:感觉又充实了一些。
解题思路:树链剖分 + 倍增
开始看时,第一问还好,第二问就不知道怎么解了。其实这两问都可以用倍增法解决。
先解释一下我理解的倍增 :记录 u 结点的 第 2 ^ i 个祖先,然后求u 的第 k 个祖先的时候,就相当于用 2 ^ i 去组合 k ,不断向上,一直到达第 k 个节点,其实每次更新的时k 的二进制中为 1 的位置。如下图,计算 u 的第 5 个祖先结点(这里不包括 u),先到达 u' 节点,然后再从 u' ,到 u'' (5 的二进制 101) 。会倍增算法后就好做了,计算第一问的时候 dis = dis[ u ] + dis[v] - 2 * dis[ LCA(u ,v)] ,第二问先判断一下要求的点在 u 到交点的链上还是在 v 到交点的链上然后再结合倍增做就ok了。
代码:
#include<iostream> #include<sstream> #include<map> #include<cmath> #include<fstream> #include<queue> #include<vector> #include<sstream> #include<cstring> #include<cstdio> #include<stack> #include<bitset> #include<ctime> #include<string> #include<cctype> #include<iomanip> #include<algorithm> using namespace std ; #define INT long long int #define L(x) (x * 2) #define R(x) (x * 2 + 1) const int INF = 0x3f3f3f3f ; const double esp = 0.0000000001 ; const double PI = acos(-1.0) ; const INT mod = 1000000007 ; const int MY = 1400 + 5 ; const int MX = 20000 + 5 ; int num ,S = 20 ,n ; int head[MX] ,dep[MX] ,dis[MX] ,p[MX][30] ; struct NODE { int v ,w ,next ; }E[MX] ; void addedge(int u ,int v ,int w) { E[num].v = v ; E[num].w = w ; E[num].next = head[u] ; head[u] = num++ ; E[num].v = u ; E[num].w = w ; E[num].next = head[v] ; head[v] = num++ ; } void dfs_find(int u ,int fa ,int w) // 处理深度、距离 { dep[u] = dep[fa] + 1 ; dis[u] = w ; p[u][0] = fa ; for(int i = 1 ;i <= S ; ++i) // 处理祖先 p[u][i] = p[p[u][i-1]][i-1] ; for(int i = head[u] ;i != -1 ;i = E[i].next) { int v = E[i].v ; if(v == fa) continue ; dfs_find(v ,u ,w + E[i].w) ; } } int LCA(int u ,int v) // 计算公共交点 { if(dep[u] > dep[v]) swap(u ,v) ; // u 的深度小于等于 v if(dep[u] < dep[v]) // 处理成同一深度 { int d = dep[v] - dep[u] ; // 深度差 for(int i = 0 ;i < S ; ++i) if(d&(1<<i)) v = p[v][i] ; } if(u != v) // 已经变成同一深度 { for(int i = S ;i >= 0 ; --i) if(p[u][i] != p[v][i]) { u = p[u][i] ; v = p[v][i] ; } u = p[u][0] ; v = p[v][0] ; } return u ; } int cunt(int u ,int k) // 计算 u 的第 k 个节点 { for(int i = 0 ;i < S ; ++i) if(k&(1<<i)) u = p[u][i] ; return u ; } int Query(int u ,int v ,int k) // 从 u 到 v 的路径上的第 k 个节点 { int z = LCA(u ,v) ; // 公共交点 if(dep[u] - dep[z] + 1 >= k) // 在 u 的这条链上 return cunt(u ,k-1) ; else // 在 v 的这条线上 { k -= dep[u] - dep[z] ; k = dep[v] - dep[z] - k + 1 ; return cunt(v ,k) ; } } int main() { //freopen("input.txt" ,"r" ,stdin) ; char s[10] ; int Tx ,u ,v ,w ,k ; scanf("%d" ,&Tx) ; while(Tx--) { scanf("%d" ,&n) ; num = 0 ; memset(head ,-1 ,sizeof(head)) ; for(int i = 1 ;i < n ; ++i) { scanf("%d%d%d" ,&u ,&v ,&w) ; addedge(u ,v ,w) ; } dep[1] = 0 ; dfs_find(1 ,1 ,0) ; while(scanf("%s" ,s) && strcmp(s ,"DONE")) { if(s[0] == 'D') // 求任意两点之间的距离 { scanf("%d%d" ,&u ,&v) ; printf("%d\n" ,dis[u] + dis[v] - 2 *dis[LCA(u ,v)]) ; } else // 询问第 k 个节点 { scanf("%d%d%d" ,&u ,&v ,&k) ; printf("%d\n" ,Query(u ,v ,k)) ; } } } return 0 ; }