[BJOI2017]树的难题:点分治+单调队列

题意

luogu / loj

给你一棵 n n n 个点的无根树。

树上的每条边具有颜色。一共有 m m m 种颜色,编号为 1 1 1 m m m。第 i i i 种颜色的权值为 c i c_i ci

对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。

请你计算,经过边数在 l l l r r r 之间的所有简单路径中,路径权值的最大值。

数据范围: 1 ≤ n , m ≤ 2 × 1 0 5 , ∣ c i ∣ ≤ 1 0 4 1\le n,m\le 2\times 10^5, |c_i|\le 10^4 1n,m2×105,ci104

题解

我们称边都在 u u u 的子树里面,且一个端点为 u u u 的路径是 u u u “引上来” 的一条路径。

将题目中的 { c } \{c\} {c} 数组用 { cost } \{\text{cost}\} {cost} 代替。

考虑点分治。

设现在在 u u u,其子为 v 1 , v 2 , ⋯   , v k v_1, v_2, \cdots, v_k v1,v2,,vk,相邻边的颜色为 c 1 , c 2 , ⋯   , c k c_1, c_2, \cdots, c_k c1,c2,,ck

我们首先将它们按颜色排序(这一步不会增大复杂度)这样相同颜色的儿子都在一块了。

考虑若 c i ≠ c j c_i\not= c_j ci=cj,则 v i v_i vi 引上来一条路径, v j v_j vj 引上来一条路径,它们拼起来的权值为两者之和。如果是相等的,那么要额外减去 cost c i \text{cost}_{c_i} costci

因此,当我们考虑到儿子 V V V,设 d V , i d_{V,i} dV,i V V V 的子树中到 V V V 的边数为 i i i,且权值最大者的权值。这样我们就可以使用单调队列进行计算。

具体地,对于与 V V V 不同的儿子 v 1 , v 2 , ⋯   , v t v_1, v_2, \cdots, v_t v1,v2,,vt,设 D i = max ⁡ 1 ≤ j ≤ t d v i , j D_i=\max_{1\le j\le t} d_{v_i,j} Di=max1jtdvi,j,然后我们只需要找寻 max ⁡ l ≤ i + j ≤ r d i + D j \max\limits_{l\le i+j\le r} d_i+D_j li+jrmaxdi+Dj,这可以通过枚举 i i i,单调队列求 D j D_j Dj 的最大值来做到。

对于与 V V V 相同的儿子,使用类似的手法可以做到。

但是这样是会出问题的。。跑单调队列的复杂度是 V V V 子树的深度最大值,但单调队列初始化的复杂度是 O ( mx ) O(\text{mx}) O(mx),其中 mx \text{mx} mx 是之前的儿子的子树深度最大值。如果我们访问的第一个儿子深度就极大,我们就没了。

因此我们更改对儿子的排序方法,以相同颜色的儿子深度最大值为第一关键字,自己子树的深度最大值为第二关键字排序,同颜色内一个一个做单调队列,做完一个颜色后与前面的所有颜色一起做单调队列,就能使复杂度达到 O ( n log ⁡ n ) O(n\log n) O(nlogn)

代码

#include 
#include 
using namespace std;

const int MAXN = 500005, INF = 2147483647;

int N, M, L, R, Cost[MAXN], ans;
struct node { int v, next, k; } E[MAXN << 1]; int head[MAXN], Elen;
void add(int u, int v, int k) { ++Elen, E[Elen].v = v, E[Elen].next = head[u], head[u] = Elen, E[Elen].k = k; }

int siz[MAXN], Sum, Rt, Maxpart; bool vis[MAXN];
void getRt(int u, int ff) {
    siz[u] = 1; int maxpart = 0;
    for (int i = head[u]; i; i = E[i].next) if (E[i].v != ff && !vis[E[i].v]) getRt(E[i].v, u), siz[u] += siz[E[i].v], maxpart = max(maxpart, siz[E[i].v]);
    maxpart = max(maxpart, Sum - siz[u]);
    if (maxpart < Maxpart) Maxpart = maxpart, Rt = u;
}

int dep[MAXN];
int maxDep, d[MAXN];
int sameMaxDep, sameD[MAXN];
int MaxDep, D[MAXN];
int getDep(int u, int ff, int d) {
    int ret = d;
    for (int i = head[u]; i; i = E[i].next) if (E[i].v != ff && !vis[E[i].v]) ret = max(ret, getDep(E[i].v, u, d + 1));
    return ret;
}
void getDis(int u, int ff, int di, int las) {
    dep[u] = dep[ff] + 1, maxDep = max(maxDep, dep[u]);
    for (int i = head[u]; i; i = E[i].next) {
        if (E[i].v == ff || vis[E[i].v]) continue;
        if (E[i].k != las) getDis(E[i].v, u, di + Cost[E[i].k], E[i].k);
        else getDis(E[i].v, u, di, E[i].k);
    }
    d[dep[u]] = max(d[dep[u]], di);
}

int colDep[MAXN];
struct qaq { int u, col, maxd; } a[MAXN]; int cnt;
bool cmp(qaq n1, qaq n2) {
    if (colDep[n1.col] != colDep[n2.col]) return colDep[n1.col] < colDep[n2.col];
    else if (n1.col == n2.col) return n1.maxd < n2.maxd;
    else return n1.col < n2.col;
}

struct triple { int val, rnk; } Q[MAXN]; int H, T;
void pop(int lef) {
    while (H <= T && Q[H].rnk < lef) ++H;
}
void insert(const triple& tr) {
    while (H <= T && Q[T].val <= tr.val) --T; Q[++T] = tr;
}

void Solve() {
    vis[Rt] = 1; int i, j;
    cnt = 0;
    for (i = head[Rt]; i; i = E[i].next) if (!vis[E[i].v]) {
        int d = getDep(E[i].v, 0, 1);
        a[++cnt] = qaq{E[i].v, E[i].k, d}, colDep[E[i].k] = max(colDep[E[i].k], d);
    }
    sort(a + 1, a + cnt + 1, cmp);
    for (i = 1; i <= cnt; ++i) {
        getDis(a[i].u, 0, Cost[a[i].col], a[i].col);

        H = 1, T = 0;
        for (j = max(0, L - maxDep); j <= min(sameMaxDep, R - maxDep); ++j) if (j) insert(triple{sameD[j], j});
        for (j = maxDep; j >= 1; --j) {
            if (H <= T) ans = max(ans, d[j] + Q[H].val - Cost[a[i].col]);
            pop(L - j + 1);
            if (R - j + 1 >= 0 && R - j + 1 <= sameMaxDep) insert(triple{sameD[R - j + 1], R - j + 1});
        }
        sameMaxDep = max(sameMaxDep, maxDep);
        for (j = 1; j <= maxDep; ++j) sameD[j] = max(sameD[j], d[j]);

        if (a[i].col != a[i + 1].col) {
            H = 1, T = 0;
            for (j = max(0, L - maxDep); j <= min(MaxDep, R - maxDep); ++j) insert(triple{D[j], j});
            for (j = maxDep; j >= 1; --j) {
                if (H <= T) ans = max(ans, sameD[j] + Q[H].val);
                pop(L - j + 1);
                if (R - j + 1 >= 0 && R - j + 1 <= MaxDep) insert(triple{D[R - j + 1], R - j + 1});
            }
            for (j = 1; j <= sameMaxDep; ++j) D[j] = max(D[j], sameD[j]), sameD[j] = -INF;
            MaxDep = max(MaxDep, sameMaxDep), sameMaxDep = 0;
        }
        for (j = 1; j <= maxDep; ++j) d[j] = -INF;
        maxDep = 0;
    }
    for (i = 1; i <= sameMaxDep; ++i) sameD[i] = -INF;
    sameMaxDep = 0;
    for (i = 1; i <= MaxDep; ++i) D[i] = -INF;
    MaxDep = 0;
    for (i = head[Rt]; i; i = E[i].next) if (!vis[E[i].v]) colDep[E[i].k] = -INF;
    for (i = head[Rt]; i; i = E[i].next) if (!vis[E[i].v]) Sum = siz[E[i].v], Rt = 0, Maxpart = INF, getRt(E[i].v, 0), Solve();
}

int main() {
    scanf("%d%d%d%d", &N, &M, &L, &R); int i, u, v, k;
    for (i = 1; i <= M; ++i) scanf("%d", &Cost[i]);
    for (i = 1; i < N; ++i) scanf("%d%d%d", &u, &v, &k), add(u, v, k), add(v, u, k);
    for (i = 1; i <= N; ++i) d[i] = D[i] = sameD[i] = -INF; ans = -INF;
    for (i = 1; i <= M; ++i) colDep[i] = -INF;
    Sum = N, Maxpart = INF, getRt(1, 0), Solve();
    printf("%d\n", ans);
    return 0;
}

你可能感兴趣的:([BJOI2017]树的难题:点分治+单调队列)