codeforces Good bye 2023 E

捋一捋

analysis

  • 考虑遍历每一个节点,以每个节点作为 l c a lca lca 思考。
  • 当前节点为 l c a lca lca 那么要想答案更大肯定是从不同子树(不同子树满足 l c a lca lca)中各选择一个节点到 l c a lca lca 不同颜色最多,假设 c i ci ci 为每个节点到当前节点不同颜色的数量,那么就要选择每个子树中最大的 c i ci ci ,然后选出最大值和次大值。
  • 考虑使用 d f s dfs dfs ,然后在回溯的过程中更新节点和答案。

problems

  • 如何找出子树的最大 c i ci ci
  • 如何更新一个节点的贡献?
  • 在回溯的过程中遇见颜色一致的节点怎样做到不重不漏?

solutions

  • 对于最大 c i ci ci 很容易想到 R M Q RMQ RMQ 的做法,在这里可以采用 d f s dfs dfs 序结合一个数据结构。
  • 如何更新一个节点的贡献,因为 c i ci ci 是向上到每一个节点的不同颜色的数量,所以当当前节点更新答案过后就根据我们 d f s dfs dfs 的出来的区间进行区间加 1 1 1 。(子节点一定会经过这个节点)
  • 对于回溯,因为是回溯,所以很容易想到更新是一个自下而上的过程。所以我们应该考虑的是祖先节点与孙子节点(当然也可能是父节点与儿子节点)颜色一致的情况。上面对于节点的更新遇见孙子节点一致的话,我们不妨用一个 s e t set set 装入每个颜色对应的节点的 d f s dfs dfs 序。孙子节点的 d f s dfs dfs 序一定大于祖先节点。所以在更新答案之前我们可以先将颜色一致的孙子节点进行区间减,然后在更新答案。
  • 综上所述数据结构可食用 线段树
  • 每个节点至多进行一次区间加,区间减和 s e t set set 的插入删除。时间复杂度 O ( n l o g n ) O(nlogn) O(nlogn)

Think Twice, Code once

#include 
#define il inline
#define get getchar
#define put putchar
#define is isdigit
#define int long long
#define dfor(i, a, b) for(int i = a; i <= b; ++i)
#define dforr(i, a, b) for(int i = a; i >= b; --i)
#define dforn(i, a, b) for(int i = a; i <= b; ++i, put(10))
#define mem(a, b, c) memset(a, b, c)
#define memc(a, b) memcpy(a, b, sizeof (a))
#define pr 114514191981
#define gg(a) cout << a, put(32)
#define INF 0x7fffffff
#define tf(x) cout << '\n' << "-> " << x << " <-" << '\n';
#define endl '\n'
#define ls i << 1
#define rs i << 1 | 1
#define la(r) tr[r].ch[0]
#define ra(r) tr[r].ch[1]
#define lowbit(x) (x & -x)
#define ct cin.tie(nullptr),ios_base::sync_with_stdio(false)
using namespace std;
typedef unsigned int ull;
typedef pair<int, int> pii;
int read(void) {
    int x = 0, f = 1; char c = get();
    while(!is(c)) (f = c == 45? -1: 1), c = get();
    while(is(c)) x = (x << 1) + (x << 3) + (c ^ 48), c = get();
    return x * f;
}
void write(int x) {
    if (x < 0) x = -x, put(45);
    if (x > 9) write(x / 10);
    put((x % 10) ^ 48);
}
#define writeln(a) write(a), put(10)
#define writesp(a) write(a), put(32)
#define writessp(a) put(32), write(a)
const int N = 3e5 + 10, M = 2e5 + 10, SN = 1e3 + 10, mod = 1e9 + 9, MOD = 998244353;
int tot, ans, a[N], ed[N], re[N], dfn[N], head[N];
vector<pii> e(N);
set<int, greater<int>> s[N];
struct p {
    int l, r, Max, tag;
}tr[N << 2];
void build(int i, int l, int r) {
    tr[i] = {l, r, 0, 0};
    if (l == r) return ;
    int mid = (l + r) >> 1;
    build(ls, l, mid), build(rs, mid + 1, r);
}
void pushup(int i) {
    tr[i].Max = max(tr[ls].Max, tr[rs].Max);
}
void pushdown(int i) {
    if (tr[i].tag) {
        tr[ls].Max += tr[i].tag;
        tr[rs].Max += tr[i].tag;
        tr[ls].tag += tr[i].tag, tr[rs].tag += tr[i].tag;
        tr[i].tag = 0;
    }
}
void modify(int i, int l, int r, int v) {
    if (l <= tr[i].l && tr[i].r <= r) {
        tr[i].Max += v;
        tr[i].tag += v;
        return ;
    }
    pushdown(i);
    if (l <= tr[ls].r) modify(ls, l, r, v);
    if (r >= tr[rs].l) modify(rs, l, r, v);
    pushup(i);
}
int query(int i, int l, int r) {
    if (l <= tr[i].l && tr[i].r <= r) return tr[i].Max;
    pushdown(i);
    int res = 0;
    if (l <= tr[ls].r) res = max(res, query(ls, l, r));
    if (r >= tr[rs].l) res = max(res, query(rs, l, r));
    return res;
}
void dfs1(int u) {
    dfn[u] = ++tot, re[tot] = u;
    for (int i = head[u]; i; i = e[i].second) {
        int v = e[i].first;
        dfs1(v);
    }
    ed[u] = tot;
}
void dfs2(int u) {
    int Max1 = 0, Max2 = 0;
    for (int i = head[u]; i; i = e[i].second) {
        int v = e[i].first;
        dfs2(v);
        while (!s[a[u]].empty() && *s[a[u]].begin() > dfn[u]) {
            int l = *s[a[u]].begin(), r = ed[re[*s[a[u]].begin()]];
            modify(1, l, r, -1);
            s[a[u]].erase(s[a[u]].begin());
        }
        int t = query(1, dfn[v], ed[v]);
        if (t > Max1) Max2 = Max1, Max1 = t;
        else if (t > Max2) Max2 = t;
    }
//    while (!s[a[u]].empty() && *s[a[u]].begin() > dfn[u]) {
//        int l = *s[a[u]].begin(), r = ed[re[*s[a[u]].begin()]];
//        modify(1, l, r, -1);
//    }
//    for (int i = head[u]; i; i = e[i].second) {
//        int v = e[i].first;
//        int t = query(1, dfn[v], ed[v]);
//        if (t > Max1) Max2 = Max1, Max1 = t;
//        else if (t > Max2) Max2 = t;
//    }
//    cout << "Max1: " << Max1 << " Max2: " << Max2 << endl;
    ans = max((Max1 + 1) * (Max2 + 1), ans);
    modify(1, dfn[u], ed[u], 1);
    s[a[u]].insert(dfn[u]);
}
signed main() {
    int cnt = 0;
    auto add = [] (int u, int v, int &cnt) {
        e[++cnt] = {v, head[u]}, head[u] = cnt;
    };
    auto init = [] (int n ,int &cnt) {
        ans = 1, cnt = tot = 0;
        memset(tr + 1, 0, sizeof(p) * 4 * n);
        for (int i = 1; i <= n; ++i) s[i].clear(), head[i] = 0;
    };
    int T = 1;
    T = read();
    while (T--) {
        int n = read();
        init(n, cnt);
        for (int i = 2; i <= n; ++i) {
            int pi = read();
            add(pi, i, cnt);
        }
        for (int i = 1; i <= n; ++i) a[i] = read();
        build(1, 1, n);
        dfs1(1);
        dfs2(1);
        writeln(ans);
    }
    return 0;
}
//12
//1 1 1 2 2 3 4 4 7 7 6
//11 2 1 11 12 8 5 8 8 5 11 7

你可能感兴趣的:(codeforces题解,算法,c++,思维,树形数据,数据结构,dfs序,线段树)