解法:求树上两点路径长度,其实就是 d e p [ u ] + d e p [ v ] − 2 ∗ d e p [ l c a ] dep[u] +dep[v]-2*dep[lca] dep[u]+dep[v]−2∗dep[lca],这里我们采用重链剖分的方式求lca,对于每次修改,都会改变树的结构,我们可以用一颗线段树维护子树的top(重链剖分的链头),这个很简单,懂树剖都会写求lca,怎么求深度?如果当前节点所在的子树没有被修改,那么就是原 d e p dep dep,如果被修改了,我们通过线段树找到节点所在的最大的被修改的子树,然后用可持久化线段树查 u u u在该子树中排第几,就能得到该节点的深度,然后水题啦
#include
using namespace std;
const int maxn = 1e5 + 10;
int rt[maxn], ls[maxn * 20], rs[maxn * 20], sum[maxn * 20], cnt, n;
int Set[maxn * 4], L[maxn], R[maxn], id, son[maxn], sz[maxn], Top[maxn];
int f[maxn], dep[maxn];
vector<int> G[maxn];
#define mid (l + r) / 2
void up(int &o, int pre, int l, int r, int k) {
o = ++cnt;
ls[o] = ls[pre];
rs[o] = rs[pre];
sum[o] = sum[pre] + 1;
if (l == r)
return;
if (k <= mid)
up(ls[o], ls[pre], l, mid, k);
else
up(rs[o], rs[pre], mid + 1, r, k);
}
int qu(int o, int pre, int l, int r, int ql, int qr) {
if (l >= ql && r <= qr)
return sum[o] - sum[pre];
int res = 0;
if (ql <= mid)
res += qu(ls[o], ls[pre], l, mid, ql, qr);
if (qr > mid)
res += qu(rs[o], rs[pre], mid + 1, r, ql, qr);
return res;
}
void update(int o, int l, int r, int ql, int qr, int v) {
if (l >= ql && r <= qr) {
Set[o] = v;
return;
}
if (ql <= mid)
update(o * 2, l, mid, ql, qr, v);
if (qr > mid)
update(o * 2 + 1, mid + 1, r, ql, qr, v);
}
int query(int o, int l, int r, int k) {
if (Set[o] || l == r)
return Set[o];
if (k <= mid)
return query(o * 2, l, mid, k);
return query(o * 2 + 1, mid + 1, r, k);
}
void dfs1(int u, int fa) {
L[u] = ++id;
up(rt[id], rt[id - 1], 1, n, u);
sz[u] = 1;
son[u] = 0;
dep[u] = dep[fa] + 1;
f[u] = fa;
for (auto v : G[u])
if (v != fa) {
dfs1(v, u);
sz[u] += sz[v];
if (sz[son[u]] < sz[v])
son[u] = v;
}
R[u] = id;
}
void dfs2(int u, int top) {
Top[u] = top;
if (son[u])
dfs2(son[u], top);
for (auto v : G[u])
if (v != f[u] && v != son[u])
dfs2(v, v);
}
int calc(int u) {
int tmp = query(1, 1, n, L[u]);
if (!tmp)
return dep[u];
return dep[f[tmp]] + qu(rt[R[tmp]], rt[L[tmp] - 1], 1, n, u, n);
}
int LCA(int u, int v) {
while (Top[u] != Top[v]) {
if (dep[Top[u]] < dep[Top[v]])
swap(u, v);
u = f[Top[u]];
}
if (dep[u] > dep[v])
swap(u, v);
return u;
}
int gao(int u, int v) {
int lca = LCA(u, v);
int o = query(1, 1, n, L[lca]);
if (o != 0) {
int ans = qu(rt[R[o]], rt[L[o] - 1], 1, n, u, n) - qu(rt[R[o]], rt[L[o] - 1], 1, n, v, n);
return abs(ans);
}
return calc(u) + calc(v) - calc(lca) * 2;
}
int main()
{
int T;
scanf("%d", &T);
while (T--) {
int u, v, q, opt;
scanf("%d", &n);
cnt = id = 0;
for (int i = 1; i <= n; i++)
G[i].clear();
for (int i = 1; i <= 4 * n; i++)
Set[i] = 0;
for (int i = 1; i < n ;i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 1);
scanf("%d", &q);
while (q--) {
scanf("%d%d", &opt, &u);
if (opt == 1) {
int cat = query(1, 1, n, L[u]);
if (!cat)
update(1, 1, n, L[u], R[u], u);
}
else {
scanf("%d", &v);
printf("%d\n", gao(u, v));
}
}
}
}