12. Game on Tree 3

题目链接:Game on Tree 3

有一棵含有 n n n 个节点的树,节点编号从 1 1 1 n n n,根节点为 1 1 1,所有非根节点均有一个正整数权值。根节点上放有一个棋子。T 和 A 两个人正在玩一个回合制游戏。一个回合中:

  • A 先选取一个非根节点,将其权值变为 0 0 0
  • 然后 T 将棋子移动到当前位置的任意一个儿子上
  • 若棋子位于叶子节点,游戏结束;T 也可以在此时强行结束游戏

游戏结束时 T 会获得棋子所在位置的权值的得分。T 想最大化得分,而 A 想最小化得分,问两人在最优策略下 T 最后的得分是多少。

先看官方题解:

由于验证 T 能否至少得到 x x x 分比较容易,我们可以二分他的得分。假设当前 check 的是至少得到 x x x 分的情况,则将树上权值小于 x x x 的节点染成白色,权值大于等于 x x x 的节点染成黑色,然后进行树形 dp:设 d p [ u ] dp[u] dp[u] 表示在以 u u u 为根节点的子树中 A 需要额外染色 d p [ u ] dp[u] dp[u] 次才能使 T 无法走到黑色节点。那么状态转移方程:

d p [ u ] = max ⁡ ( ∑ d p [ v ] − 1 , 0 ) + [ v a l u ≥ x ] dp[u]=\max(\sum dp[v]-1,0)+[val_u\ge x] dp[u]=max(dp[v]1,0)+[valux]

其中 v v v u u u 的子节点。求和后减一是因为在 T 走下去之前还有一次变颜色的机会。

最后如果 d p [ 1 ] > 0 dp[1]\gt 0 dp[1]>0 说明可以取到大于等于 x x x 的权值。

#include 
using namespace std;
using ll = long long;
const int maxn = 2e5 + 5;
vector<int> g[maxn];
int a[maxn], dp[maxn];
void dfs(int u, int f, int x) {
    dp[u] = 0;
    for (auto v : g[u]) {
        if (v == f)
            continue;
        dfs(v, u, x);
        dp[u] += dp[v];
    }
    dp[u] = max(dp[u] - 1, 0) + (a[u] >= x);
}
void solve() {
    int n;
    cin >> n;
    for (int i = 2; i <= n; ++i) {
        cin >> a[i];
    }
    for (int i = 1, u, v; i < n; ++i) {
        cin >> u >> v;
        g[u].push_back(v), g[v].push_back(u);
    }
    int l = 0, r = 1e9, ans = 0;
    while (l <= r) {
        int mid = (l + r) >> 1;
        dfs(1, 0, mid);
        if (dp[1] > 0)
            l = mid + 1, ans = mid;
        else
            r = mid - 1;
    }
    cout << ans << endl;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    // cin >> T;
    while (T--) {
        solve();
    }
}

这种染色的技巧在一些数据结构题中也有出现,在不太容易直接计算但比较容易 check 的情况下可以尝试一下。

那么能不能直接求出这个答案呢?我的室友给出了一种更为巧妙的做法:

先考虑树的高度是 1 1 1 的情况,显然 A A A 的最优选择是改变权值最大的那个叶子节点。

如果上面的这个东西是一个子树,那么它就会向父亲的地方输送除去这个权值之外的所有权值。然后又会产生一个改变权值的机会,所以就从这些权值里再删掉一个最大的。这个过程可以用可并堆来维护。

#include 
using namespace std;
using ll = long long;
const int maxn = 2e5 + 5;
const ll mod = 998244353;
vector<int> g[maxn];
int a[maxn];
int fa[maxn], ls[maxn], rs[maxn], d[maxn];
int findfa(int x) { 
    return fa[x] == x ? fa[x] : (fa[x] = findfa(fa[x]));
}
int merge(int x, int y) {
    if (!x || !y) {
        d[x] = d[y] = 0;
        return x + y;
    }
    if (a[x] < a[y]) 
        swap(x, y);
    rs[x] = merge(rs[x], y);
    if (d[ls[x]] < d[rs[x]])
        swap(ls[x], rs[x]);
    d[x] = d[rs[x]] + 1;
    return x;
}
int join(int x, int y) {
    x = findfa(x), y = findfa(y);
    if (x == y) 
        return x;
    fa[x] = fa[y] = merge(x, y);
    return fa[x];
}
int pop(int x) {
    x = findfa(x);
    int t = x, y = rs[x];
    x = ls[x];
    fa[x] = fa[y] = fa[t] = merge(x, y);
    return fa[x];
}
int dfs(int u, int f) {
    if (u != 1 && g[u].size() == 1)
        return u;
    int rt = 0;
    for (auto v: g[u]) {
        if (v == f)
            continue;
        int r = dfs(v, u);
        if (!rt)
            rt = r;
        else 
            rt = join(rt, r);
    }
    rt = pop(rt);
    return join(rt, u);
}
void solve() {
    int n;
    cin >> n;
    for (int i = 2; i <= n; ++i) {
        cin >> a[i];
    }
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
    }
    for (int i = 1, u, v; i < n; ++i) {
        cin >> u >> v;
        g[u].push_back(v), g[v].push_back(u);
    }
    int ans = dfs(1, 0);
    cout << a[ans] << endl;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T = 1;
    // cin >> T;
    while (T--) {
        solve();
    }
}

你可能感兴趣的:(一题,算法,c++)