测试地址:动态DP
做法: 本题需要用到树链剖分+线段树+矩阵乘法维护动态DP。
动态DP这个东西以前听过,但当时没有看懂,现在想来觉得是卡在矩阵乘法这个地方。这里用的不是传统的矩阵乘法。
一般的DP我们肯定会做,序列上的线性动态DP(可以用线性递推式递推的DP)很容易想到用线段树+矩阵乘法优化,但最大权值独立集这个经典树形DP模型要动态维护的话,有两个和上面问题不同的地方,第一是它不是序列,第二它的递推式有个 max \max max,这肯定不能用一般的矩阵乘法解决。下面我们来一一解决这些问题。
首先是把树上的问题转化为序列上的问题来做,显然可以想到树链剖分。根据经典的DP方程:
f ( i , 0 ) = ∑ max { f ( s o n , 0 ) , f ( s o n , 1 ) } f(i,0)=\sum \max\{f(son,0),f(son,1)\} f(i,0)=∑max{f(son,0),f(son,1)}
f ( i , 1 ) = v a l i + ∑ f ( s o n , 0 ) f(i,1)=val_i+\sum f(son,0) f(i,1)=vali+∑f(son,0)
而转移到序列上之后,一个点不在同一条重链上的其他儿子的贡献是一定的,我们把这些贡献记作 s ( i , 0 / 1 ) s(i,0/1) s(i,0/1),那么新的转移方程为:
f ( i , 0 ) = max { f ( s o n , 0 ) , f ( s o n , 1 ) } + s ( i , 0 ) f(i,0)=\max\{f(son,0),f(son,1)\}+s(i,0) f(i,0)=max{f(son,0),f(son,1)}+s(i,0)
f ( i , 1 ) = f ( s o n , 0 ) + s ( i , 1 ) f(i,1)=f(son,0)+s(i,1) f(i,1)=f(son,0)+s(i,1)
而在每次修改的时候,根据树链剖分的性质,最多有 O ( log n ) O(\log n) O(logn)条轻边,所以最多 O ( log n ) O(\log n) O(logn)个 s s s会改变,这个性质对接下来的讨论有很大帮助。
于是开始讨论第二个问题,如何加速转移?此时我们需要用一种奇特的矩阵乘法,一般的矩阵乘法是这样的:
c i , j = ∑ k a i , k b k , j c_{i,j}=\sum_ka_{i,k}b_{k,j} ci,j=∑kai,kbk,j
而这题需要用到的矩阵乘法是这样的:
c i , j = max k { a i , k + b k , j } c_{i,j}=\max_k\{a_{i,k}+b_{k,j}\} ci,j=maxk{ai,k+bk,j}
就是把加法变成 max \max max,乘法变成加法。我们发现这样的矩阵乘法和倍增+Floyd的那个合并方式完全一样,它具有和矩阵乘法一样的结合律,因此我们只要维护这样的矩阵乘法就可以了。具体转移矩阵的写法,我们把上面转移方程中 max \max max括号外面的 s ( i , 0 ) s(i,0) s(i,0)移到里面,分别加在两项中,就很显然是上面新型矩阵乘法的模式了。如果不希望从某个东西转移,在那个位置填一个 − i n f -inf −inf即可,具体的矩阵因为用latex写太麻烦我就不写了。而这样的矩阵乘法的单位矩阵是,主对角线是 0 0 0,其他位置都是 − i n f -inf −inf,证明显然。又根据上面的结论,转移矩阵每次最多有 O ( log n ) O(\log n) O(logn)个改变,因此用线段树维护单点修改即可,这样我们就以 O ( 8 n log 2 n ) O(8n\log^2n) O(8nlog2n)的时间复杂度解决了这一题。
以下是本人代码:
#include
using namespace std;
typedef long long ll;
const ll inf=1000000000ll*1000000000ll;
int n,m,first[100010],tot=0;
int son[100010],fa[100010],top[100010],bot[100010],siz[100010];
int pos[100010],qpos[100010],tim=0;
ll val[100010],f[100010][2],s[100010][2];
struct edge
{
int v,next;
}e[200010];
struct matrix
{
ll s[2][2];
}seg[400010],Ans,E,C;
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
void dfs1(int v)
{
f[v][0]=0,f[v][1]=val[v];
son[v]=0;siz[v]=1;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v])
{
fa[e[i].v]=v;
dfs1(e[i].v);
f[v][0]+=max(f[e[i].v][0],f[e[i].v][1]);
f[v][1]+=f[e[i].v][0];
siz[v]+=siz[e[i].v];
if (siz[e[i].v]>siz[son[v]])
son[v]=e[i].v;
}
}
void dfs2(int v,int tp)
{
top[v]=tp;
pos[v]=++tim,qpos[tim]=v;
if (son[v]) dfs2(son[v],tp),bot[v]=bot[son[v]];
else bot[v]=v;
s[v][0]=f[v][0]-max(f[son[v]][0],f[son[v]][1]);
s[v][1]=f[v][1]-f[son[v]][0];
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=fa[v]&&e[i].v!=son[v])
dfs2(e[i].v,e[i].v);
}
void Mult(matrix &S,matrix A,matrix B)
{
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
{
S.s[i][j]=-inf;
for(int k=0;k<2;k++)
S.s[i][j]=max(S.s[i][j],A.s[i][k]+B.s[k][j]);
}
}
void pushup(int no)
{
Mult(seg[no],seg[no<<1],seg[no<<1|1]);
}
void buildtree(int no,int l,int r)
{
if (l==r)
{
seg[no].s[0][0]=seg[no].s[0][1]=s[qpos[l]][0];
seg[no].s[1][0]=s[qpos[l]][1];
seg[no].s[1][1]=-inf;
return;
}
int mid=(l+r)>>1;
buildtree(no<<1,l,mid);
buildtree(no<<1|1,mid+1,r);
pushup(no);
}
void modify(int no,int l,int r,int x)
{
if (l==r)
{
seg[no]=C;
return;
}
int mid=(l+r)>>1;
if (x<=mid) modify(no<<1,l,mid,x);
else modify(no<<1|1,mid+1,r,x);
pushup(no);
}
void query(int no,int l,int r,int s,int t)
{
if (l>=s&&r<=t)
{
Mult(Ans,Ans,seg[no]);
return;
}
int mid=(l+r)>>1;
if (s<=mid) query(no<<1,l,mid,s,t);
if (t>mid) query(no<<1|1,mid+1,r,s,t);
}
void Modify(int x,ll v)
{
ll last0,last1;
Ans=E;
query(1,1,n,pos[top[x]],pos[bot[x]]);
last0=max(Ans.s[0][0],Ans.s[0][1]);
last1=max(Ans.s[1][0],Ans.s[1][1]);
s[x][1]+=v-val[x];
C.s[1][0]=s[x][1];
val[x]=v;
C.s[0][0]=C.s[0][1]=s[x][0];
C.s[1][0]=s[x][1];
C.s[1][1]=-inf;
modify(1,1,n,pos[x]);
x=top[x];
while(x!=1)
{
int y=fa[x];
Ans=E;
query(1,1,n,pos[x],pos[bot[x]]);
ll ans0=max(Ans.s[0][0],Ans.s[0][1]);
ll ans1=max(Ans.s[1][0],Ans.s[1][1]);
s[y][0]+=max(ans0,ans1)-max(last0,last1);
s[y][1]+=ans0-last0;
C.s[0][0]=C.s[0][1]=s[y][0];
C.s[1][0]=s[y][1];
C.s[1][1]=-inf;
Ans=E;
query(1,1,n,pos[top[y]],pos[bot[y]]);
last0=max(Ans.s[0][0],Ans.s[0][1]);
last1=max(Ans.s[1][0],Ans.s[1][1]);
modify(1,1,n,pos[y]);
x=top[y];
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%lld",&val[i]);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
insert(a,b),insert(b,a);
}
fa[1]=siz[0]=0;
dfs1(1);
f[0][0]=f[0][1]=0;
dfs2(1,1);
buildtree(1,1,n);
E.s[1][0]=E.s[0][1]=-inf;
for(int i=1;i<=m;i++)
{
int x;ll y;
scanf("%d%lld",&x,&y);
Modify(x,y);
Ans=E;
query(1,1,n,pos[1],pos[bot[1]]);
ll ans0=max(Ans.s[0][0],Ans.s[0][1]);
ll ans1=max(Ans.s[1][0],Ans.s[1][1]);
printf("%lld\n",max(ans0,ans1));
}
return 0;
}