dsu on tree简介及例题


d s u   o n   t r e e dsu ~on ~tree dsu on tree
树上启发式合并,多用于对子树的暴力询问,通过使用轻重链定义来进行优化,将算法复杂度降到 O ( n l o g n ) O(nlogn) O(nlogn)
算是一种优雅的暴力


先用一道dsu on tree比较模版的题来引一下
codeforces 600E
题意:

一棵树n个点,每个点有一个颜色 要求每个结点子树的出现哪个颜色次数最多 如果有多个颜色次数同时最多,结果为这些颜色编号相加

首先可以考虑暴力的写法,对每个结点,暴力统计他的子树每个颜色有多少,如果出现超过答案的进行更新,等于答案的进行相加
这样算下来复杂度应该是 O ( n 2 ) O(n^2) O(n2)
那么这个时候,就可以考虑用dsu on tree进行优化


首先需要知道的定义是重儿子和轻儿子
对于一个结点,他某个儿子的这个子树所包含的结点最多,那么这个儿子就是重儿子,其他的儿子就是轻儿子


原理

引入一个结论:从某一个点到根的路径中轻链的个数不会超过 l o g n logn logn

证明:
从一个结点开始往上到根的过程中,每次交汇到一个点,如果这个点是轻链,那么这个点的子节点数一定比重链的少,假设这个轻链对应子树结点大小为 y y y,这个交汇点子树结点大小为 x x x,那么 y < x / 2 yy<x/2
如果这个点到根上有z个轻链,那么需要交汇z次
那么这个点的子树结点个数会成为 n u m < = n 2 z num<=\frac{n}{2^z} num<=2zn
n u m > = 1 num>=1 num>=1,所以 z < l o g n zz<logn

并且,由于递归的性质,每次在递归return之后会回到一个父亲,那么在返回的时候,我们之前算的子树完全可以保留一个回到父亲节省下次计算,那么,我们每次最后访问重儿子,并且把重儿子的贡献保留传递给父亲继续计算,然后把轻儿子暴力计算,通过这个结论我们可以发现这样的复杂度最后只会达到 O ( n l o g n ) O(nlogn) O(nlogn)


算法

1.通过dfs计算出重儿子,轻儿子使用
2.先dfs轻儿子,然后dfs重儿子
3.计算贡献,如果当前儿子是一个重儿子,那么保留这个贡献,否则将贡献清除

可以简化为如下的模版:

int sz[N], son[N];
vector<int> g[N];
void dfs(int u, int fa) {
    sz[u] = 1;
    for (auto v : g[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
int Son;
//add函数为计算贡献的方法
void add(int u, int fa, int val) {
    
}
void dfs1(int u, int fa, int op) {
    for (auto v : g[u]) {
        if (v == fa || v == son[u]) continue;
        dfs1(v, u, 0);
    }
    if (son[u]) dfs1(son[u], u, 1), Son = son[u];
    add(u, fa, 1)
    Son = 0;
    if (!op) add(u, fa, -1);
}

这样就形成了一个模版


例题

1. Lomsat gelral(引题)

首先先看刚才的引题,只需要修改这个add函数即可
对每个暴力的点加入他的颜色,用一个cnt函数计算每个颜色的贡献
并且看这个贡献是否达到或超过最大值,进行颜色答案的计算即可
AC代码

/*
    Author : zzugzx
    Lang : C++
    Blog : blog.csdn.net/qq_43756519
*/

#include
using namespace std;

#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(), (x).end()
#define endl '\n'
#define SZ(x) (int)x.size()
#define mem(a, b) memset(a, b, sizeof(a))

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
//const int mod = 1e9 + 7;
const int mod = 998244353;

const double eps = 1e-8;
const double pi = acos(-1.0);
const int maxn = 1e6 + 10;
const int N = 1e2 + 10;
const ll inf = 0x3f3f3f3f;
const ll oo = 8e18;
const int dir[][2]={{0, 1}, {1, 0}, {0, -1}, {-1, 0}, {1, 1}, {1, -1}, {-1, 1}, {-1, -1}};

int col[maxn], sz[maxn], son[maxn];
vector<int> g[maxn];
void dfs(int u, int fa) {
    sz[u] = 1;
    for (auto v : g[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
ll ans[maxn], sum, cnt[maxn], Son, mx;
void add(int u, int fa, int val) {
    cnt[col[u]] += val;
    if (cnt[col[u]] > mx) {
        mx = cnt[col[u]];
        sum = col[u];
    }
    else if (cnt[col[u]] == mx) {
        sum += col[u];
    }
    for (auto v : g[u]) {
        if (v == fa || v == Son) continue;
        add(v, u, val);
    }

}
void dfs1(int u, int fa, int op) {
    for (auto v : g[u]) {
        if (v == fa || v == son[u]) continue;
        dfs1(v, u, 0);
    }
    if (son[u]) dfs1(son[u], u, 1), Son = son[u];
    add(u, fa, 1);
    Son = 0;
    ans[u] = sum;
    if (!op) add(u, fa, -1), sum = 0, mx = 0;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
//  freopen("in.txt", "r", stdin);
//  freopen("out.txt", "w", stdout);
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> col[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].pb(v);
        g[v].pb(u);
    }
    dfs(1, 0);
    dfs1(1, 0, 0);
    for (int i = 1; i <= n; i++)
        cout << ans[i] << ' ';
    return 0;
}

2.Tree Requests

题目链接:Codeforces 570D
题意:
一棵树上的每个点含有一个字母
m次询问给你一个深度d和一个结点v,问结点v子树中深度为d的字母是否能够重新排列组成一个回文串
题解:
重新排列组成回文串,只需要奇数字母个数不会超过1即可,所以可以通过这个进行判断,这个每次判断的复杂度是一个常数26
但发现题目有m次访问,那么我们可以统计一下每个结点的所有访问,然后把这个结点的一块计算,并对访问对应的深度进行判断即可
add的函数修改只需要统计当前子树字母即可
AC代码

/*
    Author : zzugzx
    Lang : C++
    Blog : blog.csdn.net/qq_43756519
*/

#include
using namespace std;

#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(), (x).end()
#define endl '\n'
#define SZ(x) (int)x.size()
#define mem(a, b) memset(a, b, sizeof(a))

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
//const int mod = 1e9 + 7;
const int mod = 998244353;

const double eps = 1e-8;
const double pi = acos(-1.0);
const int maxn = 5e5 + 10;
const int N = 1e2 + 10;
const ll inf = 0x3f3f3f3f;
const ll oo = 8e18;
const int dir[][2]={{0, 1}, {1, 0}, {0, -1}, {-1, 0}, {1, 1}, {1, -1}, {-1, 1}, {-1, -1}};

char c[maxn];
int sz[maxn], son[maxn], dep[maxn];
vector<int> g[maxn];
vector<pii> qry[maxn];
void dfs(int u, int fa) {
    sz[u] = 1, dep[u] = dep[fa] + 1;
    for (auto v : g[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
bool ans[maxn];
int cnt[maxn][26], Son;
void add(int u, int fa, int val) {
    cnt[dep[u]][c[u] - 'a'] += val;
    for (auto v : g[u]) {
        if (v == fa || v == Son) continue;
        add(v, u, val);
    }
}
bool check(int x) {
    int res = 0;
    for (int i = 0; i < 26; i++)
        if (cnt[x][i] & 1) res++;
    return res <= 1;
}
void dfs1(int u, int fa, int op) {
    for (auto v : g[u]) {
        if (v == fa || v == son[u]) continue;
        dfs1(v, u, 0);
    }
    if (son[u]) dfs1(son[u], u, 1), Son = son[u];
    add(u, fa, 1);
    Son = 0;
    for (auto i : qry[u]) {
        int id = i.se, d = i.fi;
        if (check(d)) ans[id] = 1;
        else ans[id] = 0;
    }
    if (!op) add(u, fa, -1);
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
//  freopen("in.txt", "r", stdin);
//  freopen("out.txt", "w", stdout);
    int n, _;
    cin >> n >> _;
    for (int i = 2; i <= n; i++) {
        int x;
        cin >> x;
        g[x].pb(i);
        g[i].pb(x);
    }
    for (int i = 1; i <= n; i++)
        cin >> c[i];
    for (int i = 1; i <= _; i++) {
        int v, h;
        cin >> v >> h;
        qry[v].pb({h, i});
    }
    dfs(1, 0);
    dfs1(1, 0, 0);
    for (int i = 1; i <= _; i++)
        if (ans[i]) cout << "Yes" << endl;
        else cout << "No" << endl;
    return 0;
}

3.Strange Memory

gym102832 2020CCPC长春F
题意:
给一棵大小为n的树,并计算
∑ i = 1 n ∑ j = i + 1 n [ a i   x o r   a j = a l c a ( i , j ) ] ( i   x o r   j ) \sum_{i=1}^{n}\sum_{j=i+1}^{n}[a_i~xor~a_j = a_{lca(i,j)}](i~xor~j) i=1nj=i+1n[ai xor aj=alca(i,j)](i xor j)
题解:
n < = 1 e 5 , a i < = 1 e 6 n<=1e5,a_i<=1e6 n<=1e5,ai<=1e6
对于当时场外看题时想出的做法是用map存子树进行递归,好像这个方法由于是2log被卡掉了,最后下来听方法是dsu on tree所以来学习了这个方法做了这道题
对于这道题lca,我们只需要统计的是一个根的不同子树互相产生的贡献即可,那么这道题就转化成为了子树询问问题,我们每次对一个子树进行查询他和之前已经加入子树之间产生的贡献,然后把这个子树加入
由于 a i < = 1 e 6 这 是 一 个 突 破 口 a_i<=1e6这是一个突破口 ai<=1e6
并且利用
a i   x o r   a j = a l c a ( i , j ) 可 以 转 化 为 a i   x o r   a l c a ( i , j ) = a j a_i~xor~a_j=a_{lca(i,j)}可以转化为a_i~xor~a_{lca(i,j)}=a_j ai xor aj=alca(i,j)ai xor alca(i,j)=aj
这性质
我们把当前访问结点u作为lca,每次访问子树计算cnt[i][j][0/1]分别表示i这个数第j位的0/1的个数
这样每次对于一个新的子树结点v需要计算的贡献就是
c n t [ a v   x o r   a u ] [ i ] [ ! ( v > > i & 1 ) ] ∗ 1 < < i , 0 < i < 20 cnt[a_v~xor~a_u][i][!(v>>i\&1)]*1<cnt[av xor au][i][!(v>>i&1)]1<<i,0<i<20
最终将答案加在一起即可
AC代码

/*
    Author : zzugzx
    Lang : C++
    Blog : blog.csdn.net/qq_43756519
*/

#include
using namespace std;

#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(), (x).end()
#define endl '\n'
#define SZ(x) (int)x.size()
#define mem(a, b) memset(a, b, sizeof(a))

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
//const int mod = 1e9 + 7;
const int mod = 998244353;

const double eps = 1e-8;
const double pi = acos(-1.0);
const int maxn = 1e6 + 10;
const int N = 1e5 + 10;
const ll inf = 0x3f3f3f3f;
const ll oo = 8e18;
const int dir[][2]={{0, 1}, {1, 0}, {0, -1}, {-1, 0}, {1, 1}, {1, -1}, {-1, 1}, {-1, -1}};

int sz[N], son[N], a[N];
vector<int> g[N];
void dfs(int u, int fa) {
    sz[u] = 1;
    for (auto v : g[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
ll ans;
int cnt[1500000][18][2], Son;
void calc(int u, int fa, int w) {
    for (int i = 0; i < 18; i++)
        ans += (1ll << i) * cnt[a[u] ^ w][i][!((u >> i) & 1)];
    for (auto v : g[u]) {
        if (v == fa || v == Son) continue;
        calc(v, u, w);
    }
}
void add(int u, int fa, int val) {
    for (int i = 0; i < 18; i++)
        cnt[a[u]][i][(u >> i) & 1] += val;
    for (auto v : g[u]) {
        if (v == fa || v == Son) continue;
        add(v, u, val);
    }
}
void dfs1(int u, int fa, int op) {
    for (auto v : g[u]) {
        if (v == fa || v == son[u]) continue;
        dfs1(v, u, 0);
    }
    if (son[u]) dfs1(son[u], u, 1), Son = son[u];
    for (auto v : g[u]) {
        if (v == fa || v == Son) continue;
        calc(v, u, a[u]);
        add(v, u, 1);
    }
    for (int i = 0; i < 18; i++)
        cnt[a[u]][i][(u >> i) & 1] += 1;
    Son = 0;
    if (!op) add(u, fa, -1);
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
//  freopen("in.txt", "r", stdin);
//  freopen("out.txt", "w", stdout);
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].pb(v);
        g[v].pb(u);
    }
    dfs(1, 0);
    dfs1(1, 0, 0);
    cout << ans << endl;
    return 0;
}

你可能感兴趣的:(dsu on tree简介及例题)