Ancient Distance

链接:https://ac.nowcoder.com/acm/contest/5669/A
来源:牛客网

题目描述
As a member of Coffee Chicken, ZYB is a boy with excellent data structure skills.

Consider the following problem: Give a rooted tree with N {N} N vertices. Vertices are numbered from 1 {1} 1 to N {N} N, and the root is always vertex 1 {1} 1. You are allowed to assign at most K {K} K key vertices so that the maximum ancient distances among all vertices is as small as possible. Denote the ancient distance of vertex x {x} x as: The distance between x {x} x and the first key vertex on the path from x {x} x to the root. For example, if the tree is 1 − 2 − 3 1-2-3 123 and the set of key vertices is { 2 } {\{2\}} {2}, then the ancient distances of all vertices are { + ∞ , 0 , 1 } \{+\infty,0,1\} {+,0,1}.

ZYB then strengthens this problem: Please find the answer for each K ∈ { 1 , 2 , … , N } K \in \{1,2,\dots,N\} K{1,2,,N}. Could you accept ZYB’s challenge?
输入描述:
The input contains multiple test cases.
For each test case, the first line contains a integer N ( 1 ≤ N ≤ 2 × 1 0 5 ) N(1 \le N \le 2 \times 10^5) N(1N2×105), indicating the number of vertices in the tree.
In the second line, there are N − 1 {N-1} N1 integers, the i {i} i-th integer f i ( 1 ≤ f i ≤ i ) f_i(1 \le f_i \le i) fi(1fii) means that there is an edge between f i f_i fi and i + 1 {i+1} i+1.
It’s guaranteed that there are at most 5 {5} 5 test cases with N > 1000 {N>1000} N>1000, and the sum of N {N} N over all test cases will not exceed 1.2 × 1 0 6 1.2 \times 10^6 1.2×106.
输出描述:
For each test case, you should output the sum of all answers instead of each of them.
示例1

输入
3
1 2
3
1 1
输出
3
2

说明
The answer for the first test case is { 2 , 1 , 0 } {\{2,1,0\}} {2,1,0}.
The answer for the second test case is { 1 , 1 , 0 } {\{1,1,0\}} {1,1,0}.
备注:
If you have any possible solution, try it bravely!
直接根据关键节点的数量 K K K求最小化的最大ancient distances不好求,但是如果给定最长的ancient distances求所需要的关键节点的最小数量比较容易。
具体做法是枚举最长距离 x x x,然后在树上未标记的节点中找最深的节点,为保证最深的节点的ancient distances至多为 x x x,同时根据贪心原则,深度越浅的关键节点相应子树节点越多,因此将 x x x节点向上第 x x x个祖先设为关键节点,可以保证尽可能多的节点ancient distances不大于 x x x。将该关键节点都标记一下。最后可以得到最长的ancient distances为 x x x时所需要的关键节点的最小数量。
从大到小枚举 x x x可以保证最终节点数量为 K K K时的最小距离 x x x被更新成最小。
由于每次求得的是 K K K的最小值,最后可能有一些 K K K对应的 x x x没有被更新到,这些 K K K对应的 x x x值是所有小于 K K K的对应 x x x中的最小值。
为防止重新建线段树导致复杂度退化,每次修改线段树时要记录修改的区间,以便恢复。
向上找第 x x x祖先个祖先可以通过树上倍增预处理做到每次 O ( l o g n ) O(logn) O(logn)复杂度,标记节点可以在dfs序上建线段树实现 O ( l o g n ) O(logn) O(logn)复杂度标记。最坏情况整个树是一条链,这样总共需要标记的关键节点数为 n 1 + n 2 + n 2 + ⋯ + n n = n H n \frac{n}{1}+\frac{n}{2}+\frac{n}{2}+\dots+\frac{n}{n}=nH_n 1n+2n+2n++nn=nHn。每次标记关键节点时倍增和线段树复杂度为 O ( l o g n ) O(logn) O(logn)。因此总体时间复杂度为 O ( n H n l o g n ) O(nH_nlogn) O(nHnlogn)

#include 

#define si(a) scanf("%d",&a)
#define sl(a) scanf("%lld",&a)
#define sd(a) scanf("%lf",&a)
#define sc(a) scahf("%c",&a);
#define ss(a) scanf("%s",a)
#define pi(a) printf("%d\n",a)
#define pl(a) printf("%lld\n",a)
#define pc(a) putchar(a)
#define ms(a) memset(a,0,sizeof(a))
#define repi(i, a, b) for(register int i=a;i<=b;++i)
#define repd(i, a, b) for(register int i=a;i>=b;--i)
#define reps(s) for(register int i=head[s];i;i=Next[i])
#define ll long long
#define ull unsigned long long
#define vi vector
#define pii pair
#define mii unordered_map
#define msi unordered_map
#define lowbit(x) ((x)&(-(x)))
#define ce(i, r) i==r?'\n':' '
#define pb push_back
#define fi first
#define se second
#define INF 0x3f3f3f3f
#define pr(x) cout<<#x<<": "<
using namespace std;

inline int qr() {
    int f = 0, fu = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-')fu = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        f = (f << 3) + (f << 1) + c - 48;
        c = getchar();
    }
    return f * fu;
}

const int N = 2e5 + 10;
int head[N], ver[N << 1], Next[N << 1], tot;
int id[N], L[N], R[N], num;
int f[N][20], d[N], ans[N];
int n, t;
vi tmp;

struct Seg_Tree {
    struct {
        int l, r, x, d;
    } t[N * 4], tmp[N * 4];

    void build(int p, int l, int r) {
        t[p].l = l, t[p].r = r;
        if (l == r) {
            t[p].d = d[id[l]], t[p].x = id[l], tmp[p] = t[p];
            return;
        }
        int mid = (l + r) >> 1;
        build(p << 1, l, mid);
        build(p << 1 | 1, mid + 1, r);
        if (t[p << 1].d >= t[p << 1 | 1].d)
            t[p].d = t[p << 1].d, t[p].x = t[p << 1].x;
        else
            t[p].d = t[p << 1 | 1].d, t[p].x = t[p << 1 | 1].x;
        tmp[p] = t[p];
    }

    void change(int p, int l, int r, int op) {
        if (op == -1 && t[p].d == -1)return;
        if (l <= t[p].l && r >= t[p].r) {
            if (op == -1)t[p].d = -1;
            else t[p] = tmp[p];
            return;
        }
        int mid = (t[p].l + t[p].r) >> 1;
        if (l <= mid)change(p << 1, l, r, op);
        if (r > mid) change(p << 1 | 1, l, r, op);
        if (op == -1) {
            if (t[p << 1].d >= t[p << 1 | 1].d)
                t[p].d = t[p << 1].d, t[p].x = t[p << 1].x;
            else t[p].d = t[p << 1 | 1].d, t[p].x = t[p << 1 | 1].x;
        } else t[p] = tmp[p];
    }
} tr;

inline void add(int x, int y) {
    ver[++tot] = y;
    Next[tot] = head[x];
    head[x] = tot;
}

inline void init() {
    t = log2(n), tot = 0, num = 0, d[1] = 1;
    repi(i, 1, n)head[i] = 0, ans[i] = INF;
}

void dfs(int x, int fa) {
    id[++num] = x, L[x] = num;
    reps(x) {
        int y = ver[i];
        if (y == fa)continue;
        f[y][0] = x, d[y] = d[x] + 1;
        repi(j, 1, t)f[y][j] = f[f[y][j - 1]][j - 1];
        dfs(y, x);
    }
    R[x] = num;
}

inline int find(int x, int u) {
    int e = log2(u);
    repi(i, 0, e)if ((u >> i) & 1)x = f[x][i];
    return max(1, x);
}

inline int solve(int x) {
    tmp.clear();
    while (tr.t[1].d != -1) {
        int fa = find(tr.t[1].x, x);
        tmp.pb(fa);
        tr.change(1, L[fa], R[fa], -1);
    }
    for (auto it:tmp)tr.change(1, L[it], R[it], 1);
    return tmp.size();
}

int main() {
    while (scanf("%d", &n) != EOF) {
        init();
        repi(i, 2, n) {
            int x = qr();
            add(i, x), add(x, i);
        }
        dfs(1, 0);
        tr.build(1, 1, n);
        repd(i, n - 1, 0) ans[solve(i)] = i;
        ll sum = 0;
        repi(i, 2, n)ans[i] = min(ans[i], ans[i - 1]);
        repi(i, 1, n)sum += ans[i];
        pl(sum);
    }
    return 0;
}

你可能感兴趣的:(ACM,线段树,倍增)