【NOIP 2018】保卫王国
给定一棵 n n n 个结点的树,点有权值。有 q q q 次询问,每次要求固定两个点是否染黑,求将树上的所有其他点染色,每条边的两端至少有一个黑点的所有方案中被染黑的点的最小权值之和。
我们可以通过树形 DP \texttt{DP} DP 在线性时间内求出一个点 u u u 如果颜色为 c c c,那么整棵树的最小代价为 f u , c f_{u, c} fu,c(具体的做法就是先从下往上树形 DP \texttt{DP} DP 得出点 u u u 如果选 c c c 这个颜色的话整个子树中的最小代价为 dp1 u , c \texttt{dp1}_{u,c} dp1u,c,然后从上往下 DP \texttt{DP} DP 得出点 u u u 的父亲如果选 c c c 这个颜色的话,以 u u u 为根 u u u 父亲子树中的最小代价为 dp2 u , c \texttt{dp2}_{u, c} dp2u,c,具体细节不在此详细赘述)。
发现如果固定两个点 u , v u, v u,v 的颜色分别为 c , d c, d c,d,那么就应该对于树上 u u u 到 v v v 路径上(不包含 u u u 和 v v v)的每一个点分别考虑是否染黑。因为 dp \texttt{dp} dp 信息是可减的,所以如果那条链上的染色的方案已经确定下来了,我们容易算出总代价:考虑路径上从上往下连续的三个点 u , v , w u, v, w u,v,w, v v v 颜色为 c c c,那么 v v v 的贡献就是 f v , c − dp1 w , d − dp2 v , d f_{v, c} - \texttt{dp1}_{w, d} - \texttt{dp2}_{v, d} fv,c−dp1w,d−dp2v,d,其中 d d d 表示能够和 c c c 相临的颜色。我们把这个贡献算在 ( v , w ) (v, w) (v,w) 这条边上。
考虑如何确定最优的链上染色方案。对于树上从上往下的两条链,其中一条链顶端的父亲是另一条链的底端,我们要合并这两条链的信息。发现 DP \texttt{DP} DP 转移只和链的两端的颜色有关,所以对于一条链只需记录它两边是否染黑即可。合并的时候枚举相邻两点的颜色,如果不全为 0 0 0 则合法。于是,我们考虑倍增。令 g i , j , k g_{i, j, k} gi,j,k 表示 i i i 点向上长度为 2 j 2^j 2j 的链,顺序为从下往上或者从上往下的 DP \texttt{DP} DP 值。这样就可以通过倍增转移,询问时像查询 LCA \texttt{LCA} LCA 一样查询即可。时间复杂度 O ( ( n + q ) log n ) O((n + q) \log n) O((n+q)logn)。
细节较多,注意特判询问时某个点在另一个点子树中的情况。
#include
#include
using namespace std;
typedef long long llong;
const int maxn = 1e5, maxm = 2e5, logn = 16; const llong infl = 1e18 + 1e9 + 1;
int n, m, a[maxn + 3], tot, ter[maxm + 3], nxt[maxm + 3], lnk[maxn + 3];
int dep[maxn + 3], cnt, l[maxn + 3], r[maxn + 3], fa[maxn + 3][logn + 3];
llong dp1[maxn + 3][2], dp2[maxn + 3][2], f[maxn + 3][2];
inline void upd_min(llong &a, llong b) {
a = min(a, b);
}
struct node {
llong dp[2][2];
node() { dp[0][0] = dp[0][1] = dp[1][0] = dp[1][1] = infl; }
llong get_min() {
llong ans = infl;
for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) {
upd_min(ans, dp[i][j]);
}
return ans;
}
friend inline node merge(const node &a, const node &b) {
node c;
for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) {
for (int k = 0; k < 2; k++) for (int l = 0; l < 2; l++) {
if (k || l) upd_min(c.dp[i][j], a.dp[i][k] + b.dp[l][j]);
}
}
return c;
}
} g[maxn + 3][logn + 3][2];
void add_edge(int u, int v) {
ter[++tot] = v;
nxt[tot] = lnk[u];
lnk[u] = tot;
}
void dfs1(int u, int p) {
dp1[u][1] = a[u];
for (int e = lnk[u], v; e; e = nxt[e]) {
if ((v = ter[e]) == p) continue;
dfs1(v, u);
dp1[u][0] += dp1[v][1];
dp1[u][1] += min(dp1[v][0], dp1[v][1]);
}
}
void dfs2(int u, int p) {
for (int e = lnk[u], v; e; e = nxt[e]) {
if ((v = ter[e]) == p) continue;
dp2[v][0] = dp2[u][1] + dp1[u][0] - dp1[v][1];
dp2[v][1] = min(dp2[u][0], dp2[u][1]) + dp1[u][1] - min(dp1[v][0], dp1[v][1]);
dfs2(v, u);
}
f[u][0] = dp2[u][1];
f[u][1] = a[u] + min(dp2[u][0], dp2[u][1]);
for (int e = lnk[u], v; e; e = nxt[e]) {
if ((v = ter[e]) == p) continue;
f[u][0] += dp1[v][1];
f[u][1] += min(dp1[v][0], dp1[v][1]);
}
}
void dfs3(int u, int p) {
dep[u] = dep[p] + 1, fa[u][0] = p;
l[u] = r[u] = ++cnt;
llong A = f[p][0] - dp1[u][1] - dp2[p][1];
llong B = f[p][1] - min(dp1[u][0], dp1[u][1]) - min(dp2[p][0], dp2[p][1]);
g[u][0][0].dp[0][0] = A, g[u][0][0].dp[1][1] = B;
g[u][0][1].dp[0][0] = A, g[u][0][1].dp[1][1] = B;
for (int i = 0, t; (t = fa[fa[u][i]][i]); i++) {
fa[u][i + 1] = t;
g[u][i + 1][0] = merge(g[u][i][0], g[fa[u][i]][i][0]);
g[u][i + 1][1] = merge(g[fa[u][i]][i][1], g[u][i][1]);
}
for (int e = lnk[u], v; e; e = nxt[e]) {
if ((v = ter[e]) == p) continue;
dfs3(v, u), r[u] = r[v];
}
}
llong solve(int u, int a, int v, int b) {
if (dep[u] > dep[v]) swap(u, v), swap(a, b);
node A, B;
B.dp[b][b] = f[v][b] - (!b ? dp2[v][1] : min(dp2[v][0], dp2[v][1]));
if (l[u] <= l[v] && l[v] <= r[u]) {
int diff = dep[v] - dep[u] - 1;
for (int i = 0; i <= logn; i++) {
if (diff >> i & 1) {
B = merge(g[v][i][1], B);
v = fa[v][i];
}
}
A.dp[a][a] = f[u][a] - (!a ? dp1[v][1] : min(dp1[v][0], dp1[v][1]));
return merge(A, B).get_min();
}
A.dp[a][a] = f[u][a] - (!a ? dp2[u][1] : min(dp2[u][0], dp2[u][1]));
int diff = dep[v] - dep[u];
for (int i = 0; i <= logn; i++) {
if (diff >> i & 1) {
B = merge(g[v][i][1], B);
v = fa[v][i];
}
}
if (fa[u][0] == fa[v][0]) goto next_part;
for (int i = logn; ~i; i--) {
if (fa[u][i] != fa[v][i]) {
A = merge(A, g[u][i][0]);
B = merge(g[v][i][1], B);
u = fa[u][i], v = fa[v][i];
}
}
next_part:;
node C; int x = fa[u][0];
C.dp[0][0] = f[x][0] - dp1[u][1] - dp1[v][1];
C.dp[1][1] = f[x][1] - min(dp1[u][0], dp1[u][1]) - min(dp1[v][0], dp1[v][1]);
return merge(A, merge(C, B)).get_min();
}
int main() {
scanf("%d %d %*s", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
for (int i = 1, u, v; i < n; i++) {
scanf("%d %d", &u, &v);
add_edge(u, v), add_edge(v, u);
}
dfs1(1, 0);
dfs2(1, 0);
/*
for (int i = 1; i <= n; i++) {
for (int j = 0; j < 2; j++) {
printf("%d %d %lld\n", i, j, f[i][j]);
}
}
*/
dfs3(1, 0);
for (int a, x, b, y; m--; ) {
scanf("%d %d %d %d", &a, &x, &b, &y);
llong ret = solve(a, x, b, y);
printf("%lld\n", ret == infl ? -1 : ret);
}
return 0;
}
/*
5 3 C3
2 4 1 3 9
1 5
5 2
5 3
3 4
1 0 3 0
2 1 3 1
1 0 5 0
5 1 C3
1 1 1 1 1
1 2
1 3
2 4
3 5
4 1 5 1
*/