一道模板题
动态 d p dp dp是猫学长发明的用来解决树上带修DP的问题的算法。
好像多数是求树上最大权独立集?
树上最大权独立集我们可以用树形 d p O ( n ) dpO(n) dpO(n)地求出来,设 f [ u ] [ 0 / 1 ] f[u][0/1] f[u][0/1]表示 u u u为根的子树 u u u选或不选的最优方案,可以列出转移式:
f [ u ] [ 0 ] + = m a x ( f [ v ] [ 1 ] , f [ v ] [ 0 ] ) f[u][0]+=max(f[v][1],f[v][0]) f[u][0]+=max(f[v][1],f[v][0])
f [ u ] [ 1 ] + = f [ v ] [ 0 ] f[u][1]+=f[v][0] f[u][1]+=f[v][0]
如果带修改该如何做呢,这个时候就要用树剖+线段树+矩阵来解决了。
树剖+线段树是常用的解决树上问题的优秀算法,想想重链一定是一个区间,就可以用线段树来维护,那么每个节点只需要维护其他轻儿子,可以把上面的转移式改一改:
f [ u ] [ 0 ] = ∑ v ∈ l i g h t s o n ( u ) m a x ( f [ v ] [ 1 ] , f [ v ] [ 0 ] ) f[u][0]=\sum_{v\in lightson(u)}max(f[v][1],f[v][0]) f[u][0]=∑v∈lightson(u)max(f[v][1],f[v][0])
f [ u ] [ 1 ] = a [ u ] + ∑ v ∈ l i g h t s o n ( u ) f [ v ] [ 0 ] f[u][1]=a[u]+\sum_{v\in lightson(u)}f[v][0] f[u][1]=a[u]+∑v∈lightson(u)f[v][0]
把新的 f f f记为 g g g
这样就实现了维护一棵树到维护一个序列的转变
这个转移怎么转化成矩阵的形式?
我们可以改变一下矩阵乘法的运算:
Mat operator *(const Mat &x) const{
Mat ret;
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
ret.g[i][j]=max(ret.g[i][j],g[i][k]+x.g[k][j]);
return ret;
}
然后想想矩阵是什么:
[ f i , 0 f i , 1 ] = [ f i − 1 , 0 f i − 1 , 1 ] × [ g i , 0 g i , 0 g i , 1 − ∞ ] \begin{bmatrix}f_{i,0}\\f_{i,1}\end{bmatrix} =\begin{bmatrix}f_{i-1,0}\\f_{i-1,1}\end{bmatrix}\times \begin{bmatrix}g_{i,0}&g_{i,0}\\g_{i,1}&-\infty \end{bmatrix} [fi,0fi,1]=[fi−1,0fi−1,1]×[gi,0gi,1gi,0−∞]
然后代入矩阵乘法算一算,顿时觉得很有道理啊
于是可以用线段树维护每个区间的矩阵积,查询的时候只需要在线段树上 q u e r y query query就好了
然后修改怎么做呢?可以发现当一个节点的权值改变了,那么它的链顶的节点的 f f f值会改变,再往上,每过一条轻边,它都会影响那个点的 g g g值,在线段树上单点修改就好了,因为轻边是 l o g n logn logn级别的,线段树修改也是 l o g n logn logn级别的,所以修改复杂度是 n l o g 2 n nlog^2n nlog2n的,查询 n l o g n nlogn nlogn,总复杂度 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)
代码如下:
#include
#include
#include
#include
#include
#define LL long long
#define N 100005
#define ls cur<<1
#define rs cur<<1|1
#define inf 0x3f3f3f3f
using namespace std;
inline int rd(){
int x=0,f=1;char c=' ';
while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
return x*f;
}
int n,m,a[N],cnt,head[N],f[N][2];
int dfn[N],rk[N],dep[N],fa[N],son[N],siz[N],top[N],ed[N],num;
struct EDGE{
int to,nxt;
}edge[N<<1];
inline void add(int x,int y){
edge[++cnt].to=y; edge[cnt].nxt=head[x]; head[x]=cnt;
}
void dfs1(int u,int fat){
siz[u]=1; int maxson=-1;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to; if(v==fat) continue; fa[v]=u;
dep[v]=dep[u]+1; dfs1(v,u); siz[u]+=siz[v];
if(siz[v]>maxson) maxson=siz[v],son[u]=v;
} return;
}
void dfs2(int u,int t){
top[u]=t; dfn[u]=++num; rk[num]=u; ed[t]=u;
if(!son[u]) return;
dfs2(son[u],t);
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to;
if(!dfn[v]) dfs2(v,v);
} return;
}
struct Mat{
int g[2][2];
Mat(){memset(g,0,sizeof g);}
Mat operator *(const Mat &x) const{
Mat ret;
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
ret.g[i][j]=max(ret.g[i][j],g[i][k]+x.g[k][j]);
return ret;
}
}val[N],node[N<<2];
void build(int cur,int L,int R){
if(L==R){
int g0=0,g1=a[rk[L]];
for(int u=rk[L],i=head[u],v;i;i=edge[i].nxt)
if((v=edge[i].to)!=fa[u] && v!=son[u])
g0+=max(f[v][0],f[v][1]),g1+=f[v][0];
node[cur].g[0][0]=node[cur].g[0][1]=g0;
node[cur].g[1][0]=g1; node[cur].g[1][1]=-inf;
val[L]=node[cur]; return;
}
int mid=(L+R)>>1;
build(ls,L,mid); build(rs,mid+1,R);
node[cur]=node[ls]*node[rs];
}
void update(int cur,int L,int R,int p){
if(L==R) {node[cur]=val[L];return;}
int mid=(L+R)>>1;
if(p<=mid) update(ls,L,mid,p);
else update(rs,mid+1,R,p);
node[cur]=node[ls]*node[rs];
}
Mat query(int cur,int L,int R,int ql,int qr){
if(ql<=L && qr>=R) return node[cur];
int mid=(L+R)>>1;
if(qr<=mid) return query(ls,L,mid,ql,qr);
if(ql>mid) return query(rs,mid+1,R,ql,qr);
return query(ls,L,mid,ql,qr)*query(rs,mid+1,R,ql,qr);
}
inline Mat ask(int u){
return query(1,1,n,dfn[top[u]],dfn[ed[top[u]]]);
}
inline void change(int u,int x){
val[dfn[u]].g[1][0]+=x-a[u]; a[u]=x;
Mat pre,now;
while(u){
pre=ask(top[u]);
update(1,1,n,dfn[u]);
now=ask(top[u]);
u=fa[top[u]];
val[dfn[u]].g[0][0]+=max(now.g[0][0],now.g[1][0])-max(pre.g[0][0],pre.g[1][0]);
val[dfn[u]].g[0][1]=val[dfn[u]].g[0][0];
val[dfn[u]].g[1][0]+=now.g[0][0]-pre.g[0][0];
}
}
void DP(int u,int fat){
f[u][1]=a[u];
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].to; if(v==fat) continue;
DP(v,u);
f[u][0]+=max(f[v][1],f[v][0]);
f[u][1]+=f[v][0];
} return;
}
inline void prework(){
dfs1(1,0); dfs2(1,1);
DP(1,0); build(1,1,n);
}
int main(){
n=rd();m=rd();
for(int i=1;i<=n;i++) a[i]=rd();
for(int i=1;i<n;i++){
int x=rd(),y=rd();
add(x,y); add(y,x);
}
prework(); int x,y; Mat ans;
while(m--){
x=rd(),y=rd();
change(x,y);
ans=ask(1);
printf("%d\n",max(ans.g[0][0],ans.g[1][0]));
}
return 0;
}