题目传送门 : http://acm.hdu.edu.cn/showproblem.php?pid=4822
给定一棵树,和树上节点 A A 、 B B 、 C C ,若节点 X X 到节点 A A 的距离严格小于到 B B 、 C C 的距离,那么称 X X 被 A A 占有。有若干询问,每次给定一组 A A 、 B B 、 C C ,问各占领的节点数。
若此题三个点减为两个点 A A 、 B B ,那么题目应该不难,我们平分 A−>B A − > B 的路径,分下奇偶,一半归 A A ,一半归 B B ,应当可以在 O((n+m)log2n) O ( ( n + m ) l o g 2 n ) 解决。口糊讨论一下做法:首先倍增求 lca l c a ,然后判断一下两个点离最近公共祖先的距离,同样用lca跳到 A−>B A − > B 路径的终点,最后分下奇偶,按子树的 size s i z e 算一算就完事了。
但加到三个点,好像一切都变复杂了啊……
咳,别颓了,有办法的。
如果说,我们仅仅只关注 A A 点占有的点,那么这些点会是上边问题分为 A,B A , B ,和 A,C A , C 所得点集的交集。这挺显然,但是看起来好像不容易实现啊。毕竟上边两个点的做法可不维护点集。如何维护点集?我们仔细观察发现,每个点集,要么是某颗子树,要么是整棵树减去某颗子树。所以我们自然就想到了 DFS D F S 序。如果按照 DFS D F S 序排序,那么一棵子树中的点在这个序列上连续。
而且此题甚至用不着生成这个序列!因为我们只需要知道这个子树所对应的区间( [DFN[x],DFN[x]+size[x]−1] [ D F N [ x ] , D F N [ x ] + s i z e [ x ] − 1 ] ),以及是整棵树删除这个子树还是选中这颗子树就可以了!剩下的就大力分类讨论一波就完事了。
听dalao说这题如果拓展到选若干个点,可以用非常dark优秀的虚树做。然而我并不会所以就不介绍了
#include
#include
#include
#include
#include
#include
using namespace std;
const int MAXN = 100010, MAXLOG = 18;
int n, m, x, y, z;
int lp, f[MAXN], lin[MAXN << 1], nxt[MAXN << 1];
inline void add(int x, int y) { lin[++lp] = y; nxt[lp] = f[x]; f[x] = lp; return; }
int size[MAXN], dfn[MAXN], timeset;
int d[MAXN][MAXLOG];
int deep[MAXN];
void clean_up() {
lp = 0;
memset(deep, 0, sizeof(deep));
memset(f, 0, sizeof(f));
memset(size, 0, sizeof(size));
memset(dfn, 0, sizeof(dfn));
memset(d, 0, sizeof(d));
timeset = 0;
return;
}
void build_tree(int pos, int fa) {
deep[pos] = deep[fa] + 1;
dfn[pos] = ++timeset;
d[pos][0] = fa;
size[pos] = 1;
for(int i = 1; i < MAXLOG; i++) d[pos][i] = d[d[pos][i - 1]][i - 1];
for(int t = f[pos]; t; t = nxt[t]) {
if(lin[t] == fa) continue;
build_tree(lin[t], pos);
size[pos] += size[lin[t]];
}
return;
}
int get_lca(int x, int y) {//求两个点的lca
if(deep[x] < deep[y]) swap(x, y);
for(int i = MAXLOG - 1; i >= 0; i--)
if(deep[d[x][i]] >= deep[y]) x = d[x][i];
if(x == y) return x;
for(int i = MAXLOG - 1; i >= 0; i--)
if(d[x][i] != d[y][i]) {
x = d[x][i]; y = d[y][i];
}
return d[x][0];
}
int jump(int pos, int step) {//跳到pos的第step个父亲
int dep = deep[pos] - step;
for(int i = MAXLOG - 1; i >= 0; i--)
if(deep[d[pos][i]] >= dep) pos = d[pos][i];
return pos;
}
int solve(int x, int y, int z) {
int a = get_lca(x, y);
int b = get_lca(x, z);
int kind1, kind2, aa1, aa2, bb1, bb2;//kind若为1,表示选中这颗子树,2表示选中除这颗子树外的数
//aa1,aa2,bb1,bb2为区间左右端点
//下面这部分一定要注意细节啊啊啊啊啊啊啊
if(deep[x] - deep[a] >= deep[y] - deep[a]) {//求左右端点
kind1 = 1;
aa1 = jump(x, deep[x] - deep[a] - 1 - (deep[x] - deep[y]) / 2);//求要处理的子树的根节点
aa2 = dfn[aa1] + size[aa1] - 1;
aa1 = dfn[aa1];
} else {
kind1 = 2;
aa1 = jump(y, deep[y] - deep[a] - (deep[y] - deep[x] + 1) / 2);
aa2 = dfn[aa1] + size[aa1];
aa1 = dfn[aa1] - 1;
}
if(deep[x] - deep[b] >= deep[z] - deep[b]) {
kind2 = 1;
bb1 = jump(x, deep[x] - deep[b] - 1 - (deep[x] - deep[z]) / 2);
bb2 = dfn[bb1] + size[bb1] - 1;
bb1 = dfn[bb1];
} else {
kind2 = 2;
bb1 = jump(z, deep[z] - deep[b] - (deep[z] - deep[x] + 1) / 2);
bb2 = dfn[bb1] + size[bb1];
bb1 = dfn[bb1] - 1;
}
if(kind1 == 1 && kind2 == 1) {//一波大力分类讨论,一定要注意每一部分是6种而并非4个
if(aa2 < bb1) return 0;
if(bb2 < aa1) return 0;
if(aa1 <= bb1 && bb2 <= aa2) return bb2 - bb1 + 1;
if(aa1 <= bb1 && aa2 < bb2) return aa2 - bb1 + 1;
if(bb1 < aa1 && bb2 <= aa2) return bb2 - aa1 + 1;
if(bb1 < aa1 && aa2 < bb2) return aa2 - aa1 + 1;
}
if(kind1 == 1 && kind2 == 2) {
if(aa2 <= bb1) return aa2 - aa1 + 1;
if(aa1 >= bb2) return aa2 - aa1 + 1;
if(aa1 <= bb1 && bb2 <= aa2) return bb1 - aa1 + 1 + aa2 - bb2 + 1;
if(aa1 <= bb1 && aa2 < bb2) return bb1 - aa1 + 1;
if(bb1 < aa1 && bb2 <= aa2) return aa2 - bb2 + 1;
if(bb1 < aa1 && aa2 < bb2) return 0;
}
if(kind1 == 2 && kind2 == 1) {
if(bb2 <= aa1) return bb2 - bb1 + 1;
if(bb1 >= aa2) return bb2 - bb1 + 1;
if(aa1 < bb1 && bb2 < aa2) return 0;
if(aa1 < bb1 && aa2 <= bb2) return bb2 - aa2 + 1;
if(bb1 <= aa1 && bb2 < aa2) return aa1 - bb1 + 1;
if(bb1 <= aa1 && aa2 <= bb2) return aa1 - bb1 + 1 + bb2 - aa2 + 1;
}
if(kind1 == 2 && kind2 == 2) {
int t = 0;
if(aa1 <= bb1) t = aa1; else t = bb1;
if(aa2 >= bb2) t += n - aa2 + 1; else t += n - bb2 + 1;
if(aa1 >= bb2) t += aa1 - bb2 + 1;
if(bb1 >= aa2) t += bb1 - aa2 + 1;
return t;
}
}
void work() {
clean_up();
scanf("%d", &n);
for(int i = 1; i < n; i++) {
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
build_tree(1, 1);//建树
scanf("%d", &m);
for(int i = 1; i <= m; i++) {
scanf("%d%d%d", &x, &y, &z);
printf("%d %d %d\n", solve(x, y, z), solve(y, x, z), solve(z, x, y));
}
return;
}
int main() {
int t;
scanf("%d", &t);
for(int i = 1; i <= t; i++) work();
return 0;
}