据说这东西猫锟在WC2018讲过,怎么一点印象都没有呢?当时应该是在冬眠
NOIP2018T6如果用动态dp去看就是一道裸题,不过因为询问是相互独立的,即修改没有时效性,可以直接用倍增代替动态dp。
先看一道题:
P4719 【模板】动态dp
如果没有修改,这题就是树形dp入门题:没有上司的晚会
设 f ( x , 0 / 1 ) f(x,0/1) f(x,0/1)分别表示以x为根的子树中,选x的最大独立集,不选x的最大独立集。
则:
f ( x , 0 ) = ∑ y ∈ s o n ( x ) m a x ( f ( y , 0 ) , f ( y , 1 ) ) f(x,0)=\sum_{y∈son(x)}max(f(y,0),f(y,1)) f(x,0)=∑y∈son(x)max(f(y,0),f(y,1))
f ( x , 1 ) = v [ x ] + ∑ y ∈ s o n ( x ) f ( y , 0 ) f(x,1)=v[x]+\sum_{y∈son(x)}f(y,0) f(x,1)=v[x]+∑y∈son(x)f(y,0)
现在我们要修改一个点x的权值v,这个修改显然只会影响到x到root路径上所有点的f,问题在于如何做到快速修改。
既然是修改一条路径,我们可以想想树链剖分。
设 g ( x , 0 / 1 ) g(x,0/1) g(x,0/1),意义是x为根的子树中,把重儿子的子树删掉的x选(1)/不选(0)的最大独立集。
对于一条x到root路径的修改, g g g受到影响的显然只有x和每条重链顶的父亲。
由于重链只有log条,所以这里暴力修改即可。
那么问题在于如何快速进行重链上的dp转移,先写一下转移:
f ( i , 0 ) = g ( i , 0 ) + m a x ( f ( i + 1 , 0 ) , f ( i + 1 , 1 ) ) f(i,0)=g(i,0)+max(f(i+1,0),f(i+1,1)) f(i,0)=g(i,0)+max(f(i+1,0),f(i+1,1))
f ( i , 1 ) = g ( i , 1 ) + f ( i + 1 , 0 ) f(i,1)=g(i,1)+f(i+1,0) f(i,1)=g(i,1)+f(i+1,0)
为了方便,i和i+1就代表重链上相邻的两个点。
接下来要引入猫锟的奇技淫巧:
我们知道矩阵乘法,现在我们定义一种新的矩阵乘法
若 C = A ∗ B C=A*B C=A∗B
则 C ( i , j ) = m a x ( A ( i , k ) + B ( k , j ) ) ( 1 < = k < = n ) C(i,j)=max(A(i,k)+B(k,j))(1<=k<=n) C(i,j)=max(A(i,k)+B(k,j))(1<=k<=n)
容易证明该矩阵乘法也满足右结合律。
把dp转为矩阵乘法:
[ f ( i , 0 ) , f ( i , 1 ) ] = [ g ( i , 0 ) , g ( i , 1 ) ] ∗ [ f ( i + 1 , 0 ) , f ( i + 1 , 0 ) f ( i + 1 , 1 ) , − ∞ ] [f(i,0),f(i,1)]=[g(i,0),g(i,1)]* \bigl[ \begin{matrix} f(i+1,0),&f(i+1,0)\\ f(i+1,1),&-∞ \end{matrix} \bigr] [f(i,0),f(i,1)]=[g(i,0),g(i,1)]∗[f(i+1,0),f(i+1,1),f(i+1,0)−∞]
那么每个点的矩阵转为:
[ g ( x , 0 ) , g ( x , 0 ) g ( x , 1 ) + v [ x ] , − ∞ ] \bigl[ \begin{matrix} g(x,0),&g(x,0)\\ g(x,1)+v[x],&-∞ \end{matrix} \bigr] [g(x,0),g(x,1)+v[x],g(x,0)−∞]
用线段树维护区间矩阵乘积就可以快速得到一个点的 f f f了。
时间复杂度 O ( n l o g 2 n ) O(n~log^2n) O(n log2n)。
Code:
#include
#include
#define ll long long
#define max(a, b) ((a) > (b) ? (a) : (b))
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define FB for(int i = fi[x], y = to[i]; i; y = to[i = nt[i]])
#define i0 i << 1
#define i1 i << 1 | 1
#define pp printf
using namespace std;
const int N = 1e5 + 5;
int n, m, v[N];
int x, y, fi[N], to[N * 2], nt[N * 2], tot;
void link(int x, int y) { nt[++ tot] = fi[x], to[tot] = y, fi[x] = tot;}
int dep[N], fa[N], siz[N], son[N], top[N];
void dg(int x) {
siz[x] = 1; dep[x] = dep[fa[x]] + 1;
FB if(y != fa[x])
fa[y] = x, dg(y), siz[x] += siz[y],
son[x] = siz[son[x]] > son[y] ? son[x] : y;
}
int w[N], tw[N], cntw;
void dfs(int x) {
w[x] = tw[x] = ++ cntw;
if(son[x]) top[son[x]] = top[x], dfs(son[x]), tw[x] = tw[son[x]];
FB if(y != fa[x] && y != son[x])
top[y] = y, dfs(y);
}
const int inf = 2147483647;
struct jz {
ll a[2][2];
jz() { memset(a, 0, sizeof a);}
};
jz operator *(jz a, jz b) {
jz c;
c.a[0][0] = max(a.a[0][0] + b.a[0][0], a.a[0][1] + b.a[1][0]);
c.a[0][1] = max(a.a[0][0] + b.a[0][1], a.a[0][1] + b.a[1][1]);
c.a[1][0] = max(a.a[1][0] + b.a[0][0], a.a[1][1] + b.a[1][0]);
c.a[1][1] = max(a.a[1][0] + b.a[0][1], a.a[1][1] + b.a[1][1]);
return c;
}
int pl, pr, px; jz pj;
jz t[N * 4], f[N];
void add(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x == y) { t[i] = f[x]; return;}
int m = x + y >> 1; add(i0, x, m); add(i1, m + 1, y);
t[i] = t[i0] * t[i1];
}
void ft(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x >= pl && y <= pr) { pj = pj * t[i]; return;}
int m = x + y >> 1; ft(i0, x, m); ft(i1, m + 1, y);
}
jz ask(int x) {
pj.a[0][0] = pj.a[1][1] = 0;
pj.a[0][1] = pj.a[1][0] = -inf;
pl = w[x], pr = tw[x]; ft(1, 1, n); return pj;
}
void xiu(int x, int lv, int nv) {
f[w[x]].a[1][0] += nv - lv;
jz la, nw;
while(x) {
la = ask(top[x]);
pl = pr = w[x]; add(1, 1, n);
nw = ask(top[x]);
x = fa[top[x]];
f[w[x]].a[0][0] += max(nw.a[0][0], nw.a[1][0]) - max(la.a[0][0], la.a[1][0]);
f[w[x]].a[0][1] = f[w[x]].a[0][0];
f[w[x]].a[1][0] += nw.a[0][0] - la.a[0][0];
}
}
int main() {
scanf("%d %d", &n, &m);
fo(i, 1, n) scanf("%d", &v[i]);
fo(i, 1, n - 1) {
scanf("%d %d", &x, &y);
link(x, y); link(y, x);
}
dg(1); top[1] = 1; dfs(1);
fo(i, 1, n) f[i].a[1][1] = -inf;
fo(i, 1, n) xiu(i, 0, v[i]);
fo(ii, 1, m) {
scanf("%d %d", &x, &y);
xiu(x, v[x], y); v[x] = y;
pj = ask(1);
printf("%lld\n", max(pj.a[0][0], pj.a[1][0]));
}
}