题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4822
本题抽象的题意是给出一棵树,有许多询问,每次询问,给出3个点,问有多少个点,到这三个点的最短距离是递增的。
首先考虑两个点的简单情况,因为是树,有特殊性,任意两点间只有唯一的一条路,找到路的中点,就可以把树分成两部分,其中一部分的点是合法解。
回到本题,问题就变成了两个子树的交集。这个考虑一个子树是否是另一子树的子树即可。用dfs序列来判断即可。
时间复杂度是O(nlogn)
1 #ifdef ONLINE_JUDGE 2 #pragma comment(linker, "/STACK:1024000000,1024000000") 3 #endif // ONLINE_JUDGE 4 5 #include <cstdio> 6 #include <cstring> 7 #include <algorithm> 8 #include <iostream> 9 using namespace std; 10 11 const int MAXV = 100010; 12 const int MAXE = 200010; 13 const int MAX_LOG = 20; 14 15 int head[MAXV], ecnt; 16 int to[MAXE], next[MAXE]; 17 int n, m, T; 18 19 void init() { 20 memset(head + 1, -1, n * sizeof(int)); 21 ecnt = 0; 22 } 23 24 void add_edge(int u, int v) { 25 to[ecnt] = v; next[ecnt] = head[u]; head[u] = ecnt++; 26 to[ecnt] = u; next[ecnt] = head[v]; head[v] = ecnt++; 27 } 28 29 int fa[MAX_LOG][MAXV]; 30 int size[MAXV], dep[MAXV]; 31 32 void dfs(int u, int f, int depth) { 33 fa[0][u] = f; size[u] = 1; dep[u] = depth; 34 for(int p = head[u]; ~p; p = next[p]) { 35 int v = to[p]; 36 if(v == f) continue; 37 dfs(v, u, depth + 1); 38 size[u] += size[v]; 39 } 40 } 41 42 void initfa() { 43 dfs(1, -1, 0); 44 for(int k = 0; k < MAX_LOG - 1; ++k) { 45 for(int u = 1; u <= n; ++u) { 46 if(fa[k][u] == -1) fa[k + 1][u] = 1; 47 else fa[k + 1][u] = fa[k][fa[k][u]]; 48 } 49 } 50 } 51 52 int upslope(int u, int p) { 53 for(int k = 0; k < MAX_LOG; ++k) { 54 if((p >> k) & 1) u = fa[k][u]; 55 } 56 return u; 57 } 58 59 int lca(int u, int v) { 60 if(dep[u] < dep[v]) swap(u, v); 61 u = upslope(u, dep[u] - dep[v]); 62 if(u == v) return u; 63 for(int k = MAX_LOG - 1; k >= 0; --k) { 64 if(fa[k][u] != fa[k][v]) 65 u = fa[k][u], v = fa[k][v]; 66 } 67 return fa[0][u]; 68 } 69 70 struct Node { 71 int type, r; 72 Node(int type, int r): type(type), r(r) {} 73 }; 74 75 Node get_middle(int a, int b, int ab) { 76 int len = dep[a] + dep[b] - 2 * dep[ab]; 77 if(dep[a] >= dep[b]) { 78 return Node(1, upslope(a, (len - 1) / 2)); 79 } else { 80 return Node(2, upslope(b, len / 2)); 81 } 82 } 83 84 int calc(int a, int b, int c, int ab, int ac) { 85 Node bn = get_middle(a, b, ab), cn = get_middle(a, c, ac); 86 if(bn.type == 1 && cn.type == 1) { 87 if(dep[bn.r] < dep[cn.r]) swap(bn, cn); 88 if(lca(bn.r, cn.r) == cn.r) return size[bn.r]; 89 else return 0; 90 } else if(bn.type == 2 && cn.type == 2) { 91 if(dep[bn.r] < dep[cn.r]) swap(bn, cn); 92 if(lca(bn.r, cn.r) == cn.r) return n - size[cn.r]; 93 else return n - size[bn.r] - size[cn.r]; 94 } else { 95 if(bn.type == 2) swap(bn, cn); 96 int t = lca(bn.r, cn.r); 97 if(t == cn.r) return n - size[cn.r]; 98 if(t == bn.r) return size[bn.r] - size[cn.r]; 99 return size[bn.r]; 100 } 101 } 102 103 int main() { 104 scanf("%d", &T); 105 while(T--) { 106 scanf("%d", &n); 107 init(); 108 for(int i = 1, u, v; i < n; ++i) { 109 scanf("%d%d", &u, &v); 110 add_edge(u, v); 111 } 112 initfa(); 113 scanf("%d", &m); 114 for(int i = 0, a, b, c; i < m; ++i) { 115 scanf("%d%d%d", &a, &b, &c); 116 int ab = lca(a, b), ac = lca(a, c), bc = lca(b, c); 117 printf("%d %d %d\n", calc(a, b, c, ab, ac), calc(b, a, c, ab, bc), calc(c, a, b, ac, bc)); 118 } 119 } 120 }