更好的阅读体验
在我认为,这个并不能说单独列出来成为一个算法,更恰当的说,是一种思想、技巧。反正挺简单的,也很有趣(谁会拒绝一个优美的暴力呢),所以写篇笔记记录一手。
dsu 一般指“disjoint set union”,即并查集。那么 dsu on tree 也就是指树上的合并和查询操作。
但是 dsu on tree 的实现却跟普通并查集没有太大联系。共同点仅在于功能上都能合并集合、查询。
dsu on tree 可称为树上启发式合并,是一种优美的暴力,合并子树的时候,把轻儿子合并到重儿子上去。
由于保存合并结果的是一个全局数组。所以每次计算新的字数时,都需要清空。我们可以先计算轻儿子,把重儿子留到最后,重儿子可以不用清零,直接把重儿子的信息拿去计算父亲。
这样一来,在暴力的基础上,将重儿子留到了最后,少算了一次重儿子,时间可以来到优秀的 O ( n log n ) O(n\log n) O(nlogn)(暴力是纯粹的 O ( n 2 ) O(n^2) O(n2))。
题目:给一棵根为1的树,每次询问子树颜色种类数。
void update(int x, int f, int flg){
cnt[col[x]] += flg;
if(cnt[col[x]] == 0 && flg == -1) cols--;
if(cnt[col[x]] == 1 && flg == 1) cols++;
for(int i = fir[x]; i; i = es[i].nxt){
int tv = es[i].v;
if(tv != f) update(tv, x, flg);
}
}
void dfs(int r, int fa){// 求子树r中的信息, fa为r的父亲
for(int i = fir[r]; i; i = es[i].nxt){ // 遍历r的邻接点
int tv = es[i].v;
if(tv != fa) dfs(tv, r);
}
update(r, fa, 1);
ans[r] = cols;
update(r, fa, -1);
}
直接 O ( n 2 ) O(n^2) O(n2) T 飞。
当然你也可以把 dfs
写成这样(更接近 dsu 的打法):
void dfs(int r, int fa){// 求子树r中的信息, fa为r的父亲
for(int i = fir[r]; i; i = es[i].nxt){ // 遍历r的邻接点
int tv = es[i].v;
if(tv != fa){
dfs(tv, r);
update(tv, r, -1);
}
}
update(r, fa, 1);
ans[r] = cols;
}
现在我们来优化一手暴力。先预处理出轻、重儿子,然后 dfs
轻儿子、再 dfs
重儿子。
void dfs(int x, int f){
for(int i = fir[x]; i; i = es[i].nxt){
int tv = es[i].v;
if(tv != son[x] && tv != f){
dfs(tv, x);
update(tv, x, -1);
}
}
if(son[x]) dfs(son[x], x);
cnt[col[x]]++;
if(cnt[col[x]] == 1) cols++;
for(int i = fir[x]; i; i = es[i].nxt){
int tv = es[i].v;
if(tv != son[x] && tv != f){
update(tv, x, 1);
}
}
ans[x] = cols;
}
时间复杂度来到优秀的 O ( n log n ) O(n\log n) O(nlogn) !!
但是为什么呢?
因为根据轻重链划分的思想,任何一条到根的路径上,轻边不会超过 log n \log n logn 条,重链是被轻边分隔的,数量也不会超过 log n \log n logn 条。
每棵子树到父亲的边为轻边,做一次 update
,最多做 log n \log n logn 次。
一次 update
可以看作是轻儿子想重儿子的合并操作。
每个节点最多合并 log n \log n logn 次,总的时间复杂度为 O ( n log n ) O(n\log n) O(nlogn) 次。
code
#include
using namespace std;
#define MAXN 100005
int n, m, ecnt, cols, col[MAXN], fir[MAXN], sz[MAXN];
int son[MAXN], cnt[MAXN], ans[MAXN];
struct edge
{
int v, nxt;
} es[MAXN << 1];
void adde(int a, int b)
{
es[++ecnt].v = b, es[ecnt].nxt = fir[a], fir[a] = ecnt;
es[++ecnt].v = a, es[ecnt].nxt = fir[b], fir[b] = ecnt;
}
void dfs1(int x, int f)
{
sz[x]++;
int maxz = 0;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
{
dfs1(tv, x);
if (sz[tv] > maxz)
maxz = sz[tv], son[x] = tv;
sz[x] += sz[tv];
}
}
}
void update(int x, int f, int flg)
{
cnt[col[x]] += flg;
if (cnt[col[x]] == 0 && flg == -1)
cols--;
if (cnt[col[x]] == 1 && flg == 1)
cols++;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
update(tv, x, flg);
}
}
void dfs2(int x, int f)
{
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
dfs2(tv, x), update(tv, x, -1);
}
if (son[x])
dfs2(son[x], x);
cnt[col[x]]++;
if (cnt[col[x]] == 1)
cols++;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
update(tv, x, 1);
}
ans[x] = cols;
}
int main()
{
int a, b;
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
scanf("%d %d", &a, &b);
adde(a, b);
}
for (int i = 1; i <= n; i++)
scanf("%d", &col[i]);
dfs1(1, 0);
dfs2(1, 0);
scanf("%d", &m);
for (int i = 1; i <= m; i++)
{
scanf("%d", &a);
printf("ans: %d\n", ans[a]);
}
return 0;
}
这个就是板题了,也可以线段树合并去做。但是 dsu on tree 明显更短,更好打。
code
#include
using namespace std;
#define MAXN 100025
#define LL long long int
LL sum, ans[MAXN];
int n, m, ecnt, maxcnt, col[MAXN], fir[MAXN], sz[MAXN], son[MAXN], cnt[MAXN];
struct edge
{
int v, nxt;
} es[MAXN << 1];
void adde(int a, int b)
{
es[++ecnt].v = b, es[ecnt].nxt = fir[a], fir[a] = ecnt;
es[++ecnt].v = a, es[ecnt].nxt = fir[b], fir[b] = ecnt;
}
void dfs1(int x, int f)
{
sz[x]++;
int maxz = 0;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
{
dfs1(tv, x);
if (sz[tv] > maxz)
maxz = sz[tv], son[x] = tv;
sz[x] += sz[tv];
}
}
}
void update(int x, int f, int flg)
{
cnt[col[x]] += flg;
if (flg == 1)
{
if (cnt[col[x]] > maxcnt)
maxcnt = cnt[col[x]], sum = col[x];
else if (cnt[col[x]] == maxcnt)
sum += col[x];
}
else
maxcnt = 0, sum = 0;
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != f)
{
update(tv, x, flg);
}
}
}
void dfs2(int x, int f)
{
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
dfs2(tv, x), update(tv, x, -1);
}
if (son[x])
dfs2(son[x], x);
cnt[col[x]]++;
if (cnt[col[x]] > maxcnt)
maxcnt = cnt[col[x]], sum = col[x];
else if (cnt[col[x]] == maxcnt)
sum += col[x];
for (int i = fir[x]; i; i = es[i].nxt)
{
int tv = es[i].v;
if (tv != son[x] && tv != f)
update(tv, x, 1);
}
ans[x] = sum;
}
int main()
{
int a, b;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &col[i]);
for (int i = 1; i < n; i++)
scanf("%d %d", &a, &b), adde(a, b);
dfs1(1, 0);
dfs2(1, 0);
for (int i = 1; i <= n; i++)
printf("%lld ", ans[i]);
return 0;
}