23.8.3 杭电暑期多校6部分题解

1004 - Tree

题目大意

有一棵树,每个节点有一个颜色可以为 a ,   b ,   c a,\space b,\space c a, b, c,需要查询树上有多少条简单路径满足 a ,   b ,   c a,\space b,\space c a, b, c 的数量相等

解题思路

考虑一个和哈希很像的想法,假设三个颜色都有一个权值,只要和为零就表示路径上三者数量相等

只要两两之间很大且互质即可,不妨设 a a a 998244353 998244353 998244353 b b b 1 e 9 + 7 1e9+7 1e9+7 c c c 为两者之和的相反数

然后就转化为了点分治问题(也可以用dsu on tree)

反复找树的重心,然后统计经过重心的合法路径,将重心删除后在剩下的子树里继续重复操作即可

统计方法可以将每条延伸出去的链值记录,然后在之前的搜过的非同一子树中的链找是否有答案累加即可

当前子树的链值用vector暂时记录,这颗子树遍历完后将相反数统计进map里即为需要寻找的能产生贡献的目标链值

注意在搜到链值为 0 0 0 时也需要对答案记录 1 1 1 的贡献

code

#include 
using namespace std;
const int N = 1e5 + 9;
const long long X[3] = {998244353, 1000000007, -1000000007 - 998244353};
struct lol {int x, y;} e[N << 1];
int t, n, a[N], ans, top[N], siz[N], num, vis[N], rt, h[N];
long long sum;
vector <long long> vct;
unordered_map <long long, int> g;
void ein(int x, int y) {
    e[++ ans].x = top[x];
    e[ans].y = y;
    top[x] = ans;
}
void dfs(int x, int fa) {
    siz[x] = 1; h[x] = 0;
    for (int i = top[x]; i; i = e[i].x) {
        int y = e[i].y;
        if (y == fa || vis[y]) continue;
        dfs(y, x);
        siz[x] += siz[y];
        h[x] = max(h[x], siz[y]);
    }
    h[x] = max(h[x], num - siz[x]);
    if (h[x] < h[rt]) rt = x;
}
void dfs2(int x, int fa, long long d, int p) {
    d += X[a[x]];
    vct.push_back(d);
    long long dis = d + X[p];
    if (dis == 0) ++ sum;
    if (g.find(-dis) != g.end()) sum += g[-dis];
    for (int i = top[x]; i; i = e[i].x) {
        int y = e[i].y;
        if (y == fa || vis[y]) continue;
        dfs2(y, x, d, p);
    }
}
void dfs1(int x) {
    vis[x] = 1; g.clear(); 
    for (int i = top[x]; i; i = e[i].x) {
        int y = e[i].y;
        if (vis[y]) continue;
        dfs2(y, x, 0, a[x]);
        for (auto v : vct) ++ g[v];
        vct.clear();
    }
    for (int i = top[x]; i; i = e[i].x) {
        int y = e[i].y;
        if (vis[y]) continue;
        rt = 0; num = siz[y];
        dfs(y, x);
        dfs1(rt);
    }
}
int main() {
    scanf("%d", &n); num = n; h[0] = 1e9;
    for (int i = 1; i <= n; ++ i) {
        char c = getchar();
        while (c == ' ' || c == '\n') c = getchar();
        a[i] = c - 'a';
    }
    for (int i = 1, u, v; i < n; ++ i)
        scanf("%d%d", &u, &v), ein(u, v), ein(v, u);
    dfs(1, 0);
    dfs1(rt);
    printf("%lld", sum);
    return 0;
}

你可能感兴趣的:(题解,点分治,数学)