点分树教程:参考博客
题意:
给定一棵 n n n个节点的树,初始时 1 1 1号节点为红色,其余为蓝色。
要求支持如下操作:
将一个节点变为红色。
询问节点 u u u到最近红色节点的距离。
共 q q q次操作。
1 ≤ n , q ≤ 1 0 5 1 \le n, q \le 10 ^5 1≤n,q≤105
点分树:将点分治时的重心按分治的层级连成一颗树,根据重心的性质树高不会超过 l o g n log \ n log n
于是暴力就变成log的了
对某个结点进行修改的话,影响的是他到根的一段log长的路径
回到这题
先建出点分树。
根据性质,点分树上(u,v)两点的lca一定在原树中(u,v)的路径上
这就意味着原树中的路径都能以他们两个端点在点分树中的lca为终点,分解为两条路径。
于是,用一个ans数组记录点分树中每个结点到它点分树子树中最近红色结点的距离。
查询某个点到最近红色结点的距离时,只要查询它在点分树中的祖先的ans+它和点分树中祖先在原树中的距离即可
修改操作类似,每次修改要更新它在点分树中的所有祖先的ans。
#include
using namespace std;
const int maxn = 1e5 + 7, inf = 0x3f3f3f3f;
vector<int> adj[maxn];
int n, m;
int sz[maxn], siz[maxn], vis[maxn], cdfa[maxn], dep[maxn], top[maxn], son[maxn], ans[maxn], fa[maxn];
void dfs(int u) {
siz[u] = 1;
for (int i = 0; i < adj[u].size(); i++) {
int v = adj[u][i];
if (v != fa[u]) {
dep[v] = dep[u] + 1;
fa[v] = u;
dfs(v);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}
}
void dfs2(int u, int t) {
top[u] = t;
for (int i = 0; i < adj[u].size(); i++) {
int v = adj[u][i];
if (v == fa[u]) continue;
if (v == son[u]) dfs2(v, t);
else dfs2(v, v);
}
}
int lca(int u, int v) {
while (top[u] ^ top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
}
void getsz(int u, int p) {
sz[u] = 1;
for (int i = 0; i < adj[u].size(); i++) {
int v = adj[u][i];
if (!vis[v] && v != p) {
getsz(v, u);
sz[u] += sz[v];
}
}
}
int getcd(int u, int p, int subsz) {
for (int i = 0; i < adj[u].size(); i++) {
int v = adj[u][i];
if (!vis[v] && v != p && sz[v] > subsz / 2)
return getcd(v, u, subsz);
}
return u;
}
void deco(int u, int p) {
getsz(u, p);
int cd = getcd(u, p, sz[u]);
cdfa[cd] = p;
//printf("cdfa[%d]=%d\n", cd, p);
vis[cd] = 1;
for (int i = 0; i < adj[cd].size(); i++) {
int v = adj[cd][i];
if (v != p && !vis[v])
deco(v, cd);
}
}
void update(int u) {
int v = u;
while (v) {
ans[v] = min(ans[v], dep[u] + dep[v] - 2 * dep[lca(u, v)]);
v = cdfa[v];
}
}
int query(int u) {
int res = inf, v = u;
while (v) {
res = min(res, ans[v] + dep[u] + dep[v] - 2 * dep[lca(u, v)]);
v = cdfa[v];
//printf("v == %d\n", v);
}
return res;
}
int main() {
cin >> n >> m;
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(1);
dfs2(1, 1);
deco(1, 0);
memset(ans, 0x3f, sizeof(ans));
update(1);
//puts("err");
for (int i = 1; i <= m; i++) {
int t, v;
cin >> t >> v;
if (t == 1) {
update(v);
} else {
cout << query(v) << "\n";
}
}
}