这里没有动态点分治。。
点分治是解决树上问题的一类算法,很多复杂度能从暴力的 O ( n 2 ) O(n^2) O(n2)降低到 O ( n l o g n ) O(nlogn) O(nlogn).
具体做法是就是求一个树的重心,树的重心的性质,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心,删去重心后,生成的多棵树尽可能平衡。
就是说子树的大小可以 l o g n logn logn的降低,这样复杂度就降下来了。
如果要达到 n l o g n nlogn nlogn的复杂度,那么就需要一个 O ( n ) O(n) O(n)的算法的遍历树。我们通过洛谷上面的点分治模板题来了解点分治的一些基本步骤。
P3806
给定一个树,问距离为K的点对是否存在。
首先我们分类一下,点对可以分为这几类。
关于问题1很好解决直接dfs一次就可以了,然后子树与子树组合就可以了。
问题2我们可以通过点分树,化简成一个小的子树的根去解决。
那么问题就回归到解决问题1了,这大概就是点分的思想,具体情况可能还要讨论一下。
我们开始写代码,其中代码最重要的就是求树的重心,这个也很简单这里不讲了。
其实很多点分T的有一种可能就是重心弄错了。
void findrt(int u, int fa) {
sz[u] = 1, son[u] = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v] || v == fa) continue;
findrt(v, u);
sz[u] += sz[v];
son[u] = max(son[u], sz[v]);
}
son[u] = max(son[u], S - sz[u]);
if (son[u] < son[rt]) rt = u;
}
然后开始分治,分治就是以重心为根,划分子树,解决,然后在找子树的重心。
void divide(int u) {
vis[u] = pd[0] = 1;
solve(u);
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
son[0] = n, S = sz[v], rt = 0;
findrt(v, u);
divide(rt);
}
}
这里面有的写法可能for循环中减去到根的写法,具体看写法了和情况了。
下面我们来看看怎么解决 u u u为根的树的问题。
void get_dis(int u, int fa) {
rev[++tot] = dis[u];
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
dis[v] = dis[u] + ed[i].w;
get_dis(v, u);
}
}
int solve(int u) {
int c = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++)
for (int k = 1; k <= m; k++)
if (que[k] >= rev[j])
ok[k] |= pd[que[k] - rev[j]];
for (int j = 1; j <= tot; j++)
if (rev[j] <= 10000000) q[++c] = rev[j], pd[rev[j]] = 1;
}
for (int i = 1; i <= c; i++) pd[q[i]] = 0;
}
其中我们可以看到get_dis函数就是求u的每一个子树的距离,而且记录下来,因为都经过了根,所以要子树与子树的组合,并且最后要清空,同时注意要记录距离为0的情况。
这样由于m的范围不打,所以总的时间复杂度 O ( n m l o g n ) O(nmlogn) O(nmlogn)是可行的。
接下来讲一讲洛谷上面的几道简单点分题。
P2634
这道题和上面一道题思路差不多,这不就是solve函数中,需要统计余数为0,1,2的和。累加ans
ll solve(int u) {
ll cc = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++) {
ll x = rev[j] % 3ll;
if (x == 0) cc += mp[0];
else cc += mp[3 - x];
}
for (int j = 1; j <= tot; j++)
mp[rev[j] % 3ll]++;
}
mp[0] = mp[1] = mp[2] = 0;
return cc;
}
注意这里得到的ans还需要乘2加n才是最终的分子。
P4149
这道题也和上面的题相似。
做法,在solve中需要一个桶记录到某个距离边权的最小值,然后和第一道题类似方法处理。
void solve(int u) {
num[0] = dep[u] = 0;
int c = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++)
if (k >= rev[j])
ans = min(ans, num[k - rev[j]] + cc[j]);
for (int j = 1; j <= tot; j++)
if (rev[j] <= k)
q[++c] = rev[j], num[rev[j]] = min(num[rev[j]], cc[j]);
}
for (int i = 1; i <= c; i++) num[q[i]] = inf;
}
CF161D Distance in Tree
这道题,其实哪儿练一练模板还是可以的,也很简单和上面的差不多,当面 n*m的dp也能过。
P4178
这道题和上面的题基本上一样,只需要改成统计小于的数字就可以了,至于统计与更新用树状数组就可以了。
void solve(int u) {
int c = 0;
add(1, 1);
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++)
if (k >= rev[j])
ans += sum(k - rev[j] + 1);
for (int j = 1; j <= tot; j++)
if (k >= rev[j])
q[++c] = rev[j], add(rev[j] + 1, 1);
}
add(1, -1);
for (int i = 1; i <= c; i++) add(q[i] + 1, -1);
P2664
这道题就很难了,想不到,看题解。
大概做法就是ans数组记录每一个点的答案,然后分成每一个点分树求对树中每一个点的贡献。
一个子树中对于根的贡献就是,每一个颜色从他到根的路径上第一次出现那么,他对根的贡献就是他的子树的大小。
对于其他点。
到u到根的这段路径上的不同颜色为num,那么对于点u的贡献就是 n u m ∗ ( s z [ r t ] − s z [ u ] ) num*(sz[rt]-sz[u]) num∗(sz[rt]−sz[u]),这个下来可以自己想一想。
然后就是点分树上除去u所在子树的其他点的不同颜色,其实就是 ( s u m − s u m u ) (sum-sum_{u}) (sum−sumu)
因此这样就可以写了。但不过你思路明白了可能代码写起来还不好写。
#include "bits/stdc++.h"
using namespace std;
inline int read() {
int x = 0;
bool f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
if (f) return x;
return 0 - x;
}
#define SZ(x) ((int)x.size())
#define ll long long
const int maxn = 100000 + 10;
const ll inf = 1e18;
struct edge {
int u, v, nxt;
} ed[maxn << 1];
int head[maxn << 1], cnt;
void add_e(int u, int v) {
ed[++cnt] = edge{u, v, head[u]};
head[u] = cnt;
}
int sz[maxn], son[maxn], vis[maxn], col[maxn], mx, rt, n, m, S;
void findrt(int u, int fa) {
sz[u] = 1, son[u] = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v] || v == fa) continue;
findrt(v, u);
sz[u] += sz[v];
son[u] = max(son[u], sz[v]);
}
son[u] = max(son[u], S - sz[u]);
if (son[u] < son[rt]) rt = u;
}
ll ans[maxn], c[maxn], mp[maxn], sum, num, tmp;
void dfs1(int u, int fa) {
mp[col[u]]++, sz[u] = 1;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v] || v == fa) continue;
dfs1(v, u);
sz[u] += sz[v];
}
if (mp[col[u]] == 1) {
sum += sz[u];
c[col[u]] += sz[u];
}
mp[col[u]]--;
}
void dfs(int u, int fa, int f) {
mp[col[u]]++;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
dfs(v, u, f);
}
if (mp[col[u]] == 1) {
sum += sz[u] * f;
c[col[u]] += sz[u] * f;
}
mp[col[u]]--;
}
void dfs2(int u, int fa) {
mp[col[u]]++;
if (mp[col[u]] == 1) {
sum -= c[col[u]];
num++;
}
ans[u] += sum + num * tmp;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
dfs2(v, u);
}
if (mp[col[u]] == 1) {
sum += c[col[u]];
num--;
}
mp[col[u]]--;
}
void clear(int u, int fa) {
c[col[u]] = mp[col[u]] = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
clear(v, u);
}
}
void solve(int u) {
dfs1(u, 0);
ans[u] += sum;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
mp[col[u]]++, sum -= sz[v], c[col[u]] -= sz[v];
dfs(v, u, -1);
mp[col[u]]--;
tmp = sz[u] - sz[v];
dfs2(v, u);
mp[col[u]]++, sum += sz[v], c[col[u]] += sz[v];
dfs(v, u, 1);
mp[col[u]]--;
}
sum = num = 0;
clear(u, 0);
}
void divide(int u) {
vis[u] = 1;
solve(u);
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
S = sz[v];
son[rt = 0] = n;
findrt(v, 0);
divide(rt);
}
}
int main() {
n = read();
for (int i = 1; i <= n; i++)
col[i] = read();
for (int i = 1, u, v; i < n; i++) {
u = read(), v = read();
add_e(u, v);
add_e(v, u);
}
S = son[rt = 0] = n;
findrt(1, 0);
divide(rt);
for (int i = 1; i <= n; i++) {
cout << ans[i] << endl;
}
return 0;
}