点分治学习模板及一些例题

点分治

这里没有动态点分治。。

点分治是解决树上问题的一类算法,很多复杂度能从暴力的 O ( n 2 ) O(n^2) O(n2)降低到 O ( n l o g n ) O(nlogn) O(nlogn).
具体做法是就是求一个树的重心,树的重心的性质,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心,删去重心后,生成的多棵树尽可能平衡。
就是说子树的大小可以 l o g n logn logn的降低,这样复杂度就降下来了。
如果要达到 n l o g n nlogn nlogn的复杂度,那么就需要一个 O ( n ) O(n) O(n)的算法的遍历树。我们通过洛谷上面的点分治模板题来了解点分治的一些基本步骤。
P3806
给定一个树,问距离为K的点对是否存在。
首先我们分类一下,点对可以分为这几类。

1.经过根的点对。

2.不经过根的点对。

关于问题1很好解决直接dfs一次就可以了,然后子树与子树组合就可以了。
问题2我们可以通过点分树,化简成一个小的子树的根去解决。
那么问题就回归到解决问题1了,这大概就是点分的思想,具体情况可能还要讨论一下。

我们开始写代码,其中代码最重要的就是求树的重心,这个也很简单这里不讲了。

其实很多点分T的有一种可能就是重心弄错了。

void findrt(int u, int fa) {
    sz[u] = 1, son[u] = 0;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v] || v == fa) continue;
        findrt(v, u);
        sz[u] += sz[v];
        son[u] = max(son[u], sz[v]);
    }
    son[u] = max(son[u], S - sz[u]);
    if (son[u] < son[rt]) rt = u;
}

然后开始分治,分治就是以重心为根,划分子树,解决,然后在找子树的重心。

void divide(int u) {
    vis[u] = pd[0] = 1;
    solve(u);
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        son[0] = n, S = sz[v], rt = 0;
        findrt(v, u);
        divide(rt);
    }
}

这里面有的写法可能for循环中减去到根的写法,具体看写法了和情况了。
下面我们来看看怎么解决 u u u为根的树的问题。

void get_dis(int u, int fa) {
    rev[++tot] = dis[u];
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (v == fa || vis[v]) continue;
        dis[v] = dis[u] + ed[i].w;
        get_dis(v, u);
    }
}
int solve(int u) {
    int c = 0;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        tot = 0;
        dis[v] = ed[i].w;
        get_dis(v, u);
        for (int j = 1; j <= tot; j++)
            for (int k = 1; k <= m; k++)
                if (que[k] >= rev[j])
                    ok[k] |= pd[que[k] - rev[j]];
        for (int j = 1; j <= tot; j++)
            if (rev[j] <= 10000000) q[++c] = rev[j], pd[rev[j]] = 1;
    }
    for (int i = 1; i <= c; i++) pd[q[i]] = 0;
}

其中我们可以看到get_dis函数就是求u的每一个子树的距离,而且记录下来,因为都经过了根,所以要子树与子树的组合,并且最后要清空,同时注意要记录距离为0的情况。
这样由于m的范围不打,所以总的时间复杂度 O ( n m l o g n ) O(nmlogn) Onmlogn)是可行的。
接下来讲一讲洛谷上面的几道简单点分题。
P2634
这道题和上面一道题思路差不多,这不就是solve函数中,需要统计余数为0,1,2的和。累加ans

ll solve(int u) {
    ll cc = 0;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        tot = 0;
        dis[v] = ed[i].w;
        get_dis(v, u);
        for (int j = 1; j <= tot; j++) {
            ll x = rev[j] % 3ll;
            if (x == 0) cc += mp[0];
            else cc += mp[3 - x];
        }
        for (int j = 1; j <= tot; j++)
            mp[rev[j] % 3ll]++;
    }
    mp[0] = mp[1] = mp[2] = 0;
    return cc;
}

注意这里得到的ans还需要乘2加n才是最终的分子。
P4149
这道题也和上面的题相似。
做法,在solve中需要一个桶记录到某个距离边权的最小值,然后和第一道题类似方法处理。

void solve(int u) {
    num[0] = dep[u] = 0;
    int c = 0;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        tot = 0;
        dis[v] = ed[i].w;
        get_dis(v, u);
        for (int j = 1; j <= tot; j++)
            if (k >= rev[j])
                ans = min(ans, num[k - rev[j]] + cc[j]);
        for (int j = 1; j <= tot; j++)
            if (rev[j] <= k)
                q[++c] = rev[j], num[rev[j]] = min(num[rev[j]], cc[j]);
    }
    for (int i = 1; i <= c; i++) num[q[i]] = inf;
}

CF161D Distance in Tree
这道题,其实哪儿练一练模板还是可以的,也很简单和上面的差不多,当面 n*m的dp也能过。

P4178
这道题和上面的题基本上一样,只需要改成统计小于的数字就可以了,至于统计与更新用树状数组就可以了。

void solve(int u) {
    int c = 0;
    add(1, 1);
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        tot = 0;
        dis[v] = ed[i].w;
        get_dis(v, u);
        for (int j = 1; j <= tot; j++)
            if (k >= rev[j])
                ans += sum(k - rev[j] + 1);
        for (int j = 1; j <= tot; j++)
            if (k >= rev[j])
                q[++c] = rev[j], add(rev[j] + 1, 1);
    }
    add(1, -1);
    for (int i = 1; i <= c; i++) add(q[i] + 1, -1);

P2664
这道题就很难了,想不到,看题解。
大概做法就是ans数组记录每一个点的答案,然后分成每一个点分树求对树中每一个点的贡献。
一个子树中对于根的贡献就是,每一个颜色从他到根的路径上第一次出现那么,他对根的贡献就是他的子树的大小。
对于其他点。
到u到根的这段路径上的不同颜色为num,那么对于点u的贡献就是 n u m ∗ ( s z [ r t ] − s z [ u ] ) num*(sz[rt]-sz[u]) num(sz[rt]sz[u]),这个下来可以自己想一想。
然后就是点分树上除去u所在子树的其他点的不同颜色,其实就是 ( s u m − s u m u ) (sum-sum_{u}) (sumsumu)
因此这样就可以写了。但不过你思路明白了可能代码写起来还不好写。

#include "bits/stdc++.h"

using namespace std;
inline int read() {
    int x = 0;
    bool f = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    if (f) return x;
    return 0 - x;
}
#define SZ(x) ((int)x.size())
#define ll long long

const int maxn = 100000 + 10;
const ll inf = 1e18;
struct edge {
    int u, v, nxt;
} ed[maxn << 1];
int head[maxn << 1], cnt;
void add_e(int u, int v) {
    ed[++cnt] = edge{u, v, head[u]};
    head[u] = cnt;
}
int sz[maxn], son[maxn], vis[maxn], col[maxn], mx, rt, n, m, S;

void findrt(int u, int fa) {
    sz[u] = 1, son[u] = 0;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v] || v == fa) continue;
        findrt(v, u);
        sz[u] += sz[v];
        son[u] = max(son[u], sz[v]);
    }
    son[u] = max(son[u], S - sz[u]);
    if (son[u] < son[rt]) rt = u;
}
ll ans[maxn], c[maxn], mp[maxn], sum, num, tmp;

void dfs1(int u, int fa) {
    mp[col[u]]++, sz[u] = 1;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v] || v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v];
    }
    if (mp[col[u]] == 1) {
        sum += sz[u];
        c[col[u]] += sz[u];
    }
    mp[col[u]]--;
}

void dfs(int u, int fa, int f) {
    mp[col[u]]++;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (v == fa || vis[v]) continue;
        dfs(v, u, f);
    }
    if (mp[col[u]] == 1) {
        sum += sz[u] * f;
        c[col[u]] += sz[u] * f;
    }
    mp[col[u]]--;
}

void dfs2(int u, int fa) {
    mp[col[u]]++;
    if (mp[col[u]] == 1) {
        sum -= c[col[u]];
        num++;
    }
    ans[u] += sum + num * tmp;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (v == fa || vis[v]) continue;
        dfs2(v, u);
    }
    if (mp[col[u]] == 1) {
        sum += c[col[u]];
        num--;
    }
    mp[col[u]]--;
}

void clear(int u, int fa) {
    c[col[u]] = mp[col[u]] = 0;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (v == fa || vis[v]) continue;
        clear(v, u);
    }
}


void solve(int u) {
    dfs1(u, 0);
    ans[u] += sum;
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        mp[col[u]]++, sum -= sz[v], c[col[u]] -= sz[v];
        dfs(v, u, -1);
        mp[col[u]]--;
        tmp = sz[u] - sz[v];
        dfs2(v, u);
        mp[col[u]]++, sum += sz[v], c[col[u]] += sz[v];
        dfs(v, u, 1);
        mp[col[u]]--;
    }
    sum = num = 0;
    clear(u, 0);

}

void divide(int u) {
    vis[u] = 1;
    solve(u);
    for (int i = head[u]; i; i = ed[i].nxt) {
        int v = ed[i].v;
        if (vis[v]) continue;
        S = sz[v];
        son[rt = 0] = n;
        findrt(v, 0);
        divide(rt);
    }
}

int main() {
    n = read();
    for (int i = 1; i <= n; i++)
        col[i] = read();
    for (int i = 1, u, v; i < n; i++) {
        u = read(), v = read();
        add_e(u, v);
        add_e(v, u);
    }
    S = son[rt = 0] = n;
    findrt(1, 0);
    divide(rt);
    for (int i = 1; i <= n; i++) {
        cout << ans[i] << endl;
    }
    return 0;
}

你可能感兴趣的:(ACM题解)