N N N个点以点 1 1 1为根的树,在树上确定 K K K个关键点,每个点的权值 v a l val val为点与点到根节点上碰到的第一个关键点的距离(若路径上没有关键点, 那么权值为 inf \inf inf),答案为所有点中最大权值的最小值。
现在求 K = 1 , 2 , . . . , N K=1,2,...,N K=1,2,...,N的答案之和
题意比较难理解,看看样例应该能懂
如果确定 K K K值,那么我们单次得到答案应该是可以用树形dp来解决的,总体的复杂度就是 O ( N 2 ) O(N^2) O(N2)
现在反过来考虑,考虑已经确定答案的情况下,最少需要多少关键点
因为点的权值为点与点到根节点上碰到的第一个关键点的距离,其实就是点与其祖先结点中的关键点的深度之差
这里确定答案为 a n s ans ans后,我们要将树中深度 d e p > a n s dep>ans dep>ans都处理掉
对于树中最深的点 u u u,将从 u u u一直往上 a n s ans ans个点得到的 a n c e s t o r ancestor ancestor设为关键点,这样 v a l [ u ] = a n s val[u]=ans val[u]=ans
至于为什么恰好是 a n s ans ans级祖先,因为关键点只要是 u u u的 [ 0 , a n s ] [0, ans] [0,ans]级祖先就可以使 v a l [ u ] ≤ a n s val[u] \leq ans val[u]≤ans
考虑到一个关键点 k e y key key的设立,会使得 k e y key key的子树中所有的点权值都变小,所以我们要让关键点影响的点越多越好(离根越近越好),即 d e p [ k e y ] dep[key] dep[key]越小越好,所以我们选择恰好是 a n s ans ans级祖先
所以每次我们确定一个 k e y key key后,其子树中的点均有 v a l ≤ a n s val \leq ans val≤ans,当我们继续找深度最大的点,是不会再考虑 k e y key key及其子树中的点,因此我们将他们覆盖掉,即我们每次找未被覆盖的点中 d e p dep dep最大的点,不断重复这个过程,直到所有点均满足 v a l ≤ a n s val \leq ans val≤ans
可以看到,我们需要的操作是寻找树中深度最大的结点,以及覆盖一个点的子树,所以考虑按dfs序建立线段树
void dfs(int u, int fa) {
st[u] = ++dfnt;//dfs到这个点的时刻
for (auto &v: g[u]) if (v != fa) dfs(v, u);
ed[u] = dfnt;//dfs完这个点及其子树的时刻
//所以[st, ed]这之间就是u及其子树
}
可以看到,只要我们覆盖线段树中 [ s t [ u ] , e d [ u ] ] [st[u], ed[u] ] [st[u],ed[u]]即可覆盖点 u u u子树中所有结点
至于查询最大深度结点,就是线段树基本操作了
还有一个操作,我们要查找一个点 u u u的 k − t h k-th k−th祖先
这里可以倍增处理:
void dfs() {
//dfs过程中预处理每个点的2^i级祖先
for (int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
}
int kthFa(int u, int k) {//查询的时候倍增跳祖先, 复杂度log(k)
int bit = 0;
while (k) {
if (k & 1) u = anc[u][bit];
k >>= 1;
bit++;
}
return u;
}
对于每一个答案 a n s ans ans,我们至多确定 ⌈ N a n s + 1 ⌉ \lceil \frac{N}{ans+1} \rceil ⌈ans+1N⌉个关键点
考虑关键点会对子树中所有点有影响,那么我们最坏的情况就是子树是一条链,所影响的点只有 a n s + 1 ans+1 ans+1个
所以每个答案只要确定 ⌈ N a n s + 1 ⌉ \lceil \frac{N}{ans+1} \rceil ⌈ans+1N⌉个关键点就能影响到所有点
而我们单次操作就是线段树的覆盖及其查询,还有查找k级祖先,这些操作都是 l o g N logN logN级别的
a n s = 0 , 1 , . . . , N − 1 ans = 0,1,...,N-1 ans=0,1,...,N−1
T = ⌈ N 1 ⌉ l o g N + ⌈ N 2 ⌉ l o g N + ⋅ ⋅ ⋅ + ⌈ N n ⌉ l o g N = O ( N l o g 2 N ) T=\lceil \frac{N}{1} \rceil logN + \lceil \frac{N}{2} \rceil logN + ···+ \lceil \frac{N}{n} \rceil logN=O(Nlog^2N) T=⌈1N⌉logN+⌈2N⌉logN+⋅⋅⋅+⌈nN⌉logN=O(Nlog2N)
总体复杂度 O ( N l o g 2 N ) O(Nlog^2N) O(Nlog2N)
#include
#define lc u<<1
#define rc u<<1|1
#define mid (t[u].l+t[u].r)/2
using namespace std;
typedef long long ll;
const int MAX = 2e5 + 10;
int N;
int ans[MAX];
vector<int> store;
vector<int> g[MAX];
int anc[MAX][20], dep[MAX], st[MAX], ed[MAX], nodeOf[MAX], dfnt;
void dfs(int u, int fa) {
dep[u] = dep[anc[u][0] = fa] + 1, nodeOf[st[u] = ++dfnt] = u;
for (int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
for (auto &v: g[u])
if (v != fa) dfs(v, u);
ed[u] = dfnt;
}
//k级祖先, 倍增跳
int kthFa(int u, int k) { int bit = 0; while (k) { if (k & 1) u = anc[u][bit]; k >>= 1; bit++;} return u; }
struct SegmentTree {
int mx, node, l, r;
bool cover;
} t[MAX << 2];
void push_up(int u) {
t[u].mx = 0;
if (!t[lc].cover && t[lc].mx > t[u].mx) t[u].node = t[lc].node, t[u].mx = t[lc].mx;
if (!t[rc].cover && t[rc].mx > t[u].mx) t[u].node = t[rc].node, t[u].mx = t[rc].mx;
}
void build(int u, int l, int r) {
t[u].l = l, t[u].r = r;
if (l == r) {
t[u].mx = dep[t[u].node = nodeOf[l]];
return;
}
build(lc, l, mid); build(rc, mid + 1, r);
push_up(u);
}
void update(int u, int ql, int qr, int k) {
if (ql <= t[u].l && t[u].r <= qr) {
t[u].cover = k;
return;
}
if (ql <= mid) update(lc, ql, qr, k);
if (qr > mid) update(rc, ql, qr, k);
push_up(u);
}
void init() {
dfnt = 0;
for (int i = 1; i <= N; i++) g[i].clear();
}
int main() {
while (~scanf("%d", &N)) {
init();
for (int i = 2; i <= N; i++) {
int x; scanf("%d", &x);
g[x].push_back(i); g[i].push_back(x);
}
dep[0] = -1; dfs(1, 0);
build(1, 1, dfnt);//按dfs序建立线段树
for (int i = 1; i <= N; i++) ans[i] = N - 1;
for (int nowans = N - 1; nowans >= 0; nowans--) {
int cost = 0;
store.clear();
while (1) {
cost++;
if (t[1].mx <= nowans) break;//如果树中未覆盖的最大值<=ans, 那么就不需要再覆盖了
int u = t[1].node;
u = kthFa(u, nowans);//让最深的点u的贡献值变为nowans, 所以从u往上找nowans个点, 即u的nowans-th祖先
store.push_back(u);//将覆盖的部分记录, 之后线段树清空
update(1, st[u], ed[u], 1);//覆盖u的子树, 因为我们将u设为关键点, 且u中最深的点都满足<=nowans, 因此整个子树都满足
}
ans[cost] = nowans;//设cost个关键点的答案
for (auto &i: store) update(1, st[i], ed[i], 0);//清空线段树
}
for (int i = 2; i <= N; i++) ans[i] = min(ans[i], ans[i - 1]);//如果设更少的关键点的答案能更小, 那答案就取小的
ll sum = 0;
for (int i = 1; i <= N; i++) sum += ans[i];
printf("%lld\n", sum);
}
return 0;
}