【LuoguP4719】动态DP模板-树链剖分+线段树+矩阵乘法

测试地址:动态DP
做法: 本题需要用到树链剖分+线段树+矩阵乘法维护动态DP。
动态DP这个东西以前听过,但当时没有看懂,现在想来觉得是卡在矩阵乘法这个地方。这里用的不是传统的矩阵乘法。
一般的DP我们肯定会做,序列上的线性动态DP(可以用线性递推式递推的DP)很容易想到用线段树+矩阵乘法优化,但最大权值独立集这个经典树形DP模型要动态维护的话,有两个和上面问题不同的地方,第一是它不是序列,第二它的递推式有个 max ⁡ \max max,这肯定不能用一般的矩阵乘法解决。下面我们来一一解决这些问题。
首先是把树上的问题转化为序列上的问题来做,显然可以想到树链剖分。根据经典的DP方程:
f ( i , 0 ) = ∑ max ⁡ { f ( s o n , 0 ) , f ( s o n , 1 ) } f(i,0)=\sum \max\{f(son,0),f(son,1)\} f(i,0)=max{f(son,0),f(son,1)}
f ( i , 1 ) = v a l i + ∑ f ( s o n , 0 ) f(i,1)=val_i+\sum f(son,0) f(i,1)=vali+f(son,0)
而转移到序列上之后,一个点不在同一条重链上的其他儿子的贡献是一定的,我们把这些贡献记作 s ( i , 0 / 1 ) s(i,0/1) s(i,0/1),那么新的转移方程为:
f ( i , 0 ) = max ⁡ { f ( s o n , 0 ) , f ( s o n , 1 ) } + s ( i , 0 ) f(i,0)=\max\{f(son,0),f(son,1)\}+s(i,0) f(i,0)=max{f(son,0),f(son,1)}+s(i,0)
f ( i , 1 ) = f ( s o n , 0 ) + s ( i , 1 ) f(i,1)=f(son,0)+s(i,1) f(i,1)=f(son,0)+s(i,1)
而在每次修改的时候,根据树链剖分的性质,最多有 O ( log ⁡ n ) O(\log n) O(logn)条轻边,所以最多 O ( log ⁡ n ) O(\log n) O(logn) s s s会改变,这个性质对接下来的讨论有很大帮助。
于是开始讨论第二个问题,如何加速转移?此时我们需要用一种奇特的矩阵乘法,一般的矩阵乘法是这样的:
c i , j = ∑ k a i , k b k , j c_{i,j}=\sum_ka_{i,k}b_{k,j} ci,j=kai,kbk,j
而这题需要用到的矩阵乘法是这样的:
c i , j = max ⁡ k { a i , k + b k , j } c_{i,j}=\max_k\{a_{i,k}+b_{k,j}\} ci,j=maxk{ai,k+bk,j}
就是把加法变成 max ⁡ \max max,乘法变成加法。我们发现这样的矩阵乘法和倍增+Floyd的那个合并方式完全一样,它具有和矩阵乘法一样的结合律,因此我们只要维护这样的矩阵乘法就可以了。具体转移矩阵的写法,我们把上面转移方程中 max ⁡ \max max括号外面的 s ( i , 0 ) s(i,0) s(i,0)移到里面,分别加在两项中,就很显然是上面新型矩阵乘法的模式了。如果不希望从某个东西转移,在那个位置填一个 − i n f -inf inf即可,具体的矩阵因为用latex写太麻烦我就不写了。而这样的矩阵乘法的单位矩阵是,主对角线是 0 0 0,其他位置都是 − i n f -inf inf,证明显然。又根据上面的结论,转移矩阵每次最多有 O ( log ⁡ n ) O(\log n) O(logn)个改变,因此用线段树维护单点修改即可,这样我们就以 O ( 8 n log ⁡ 2 n ) O(8n\log^2n) O(8nlog2n)的时间复杂度解决了这一题。
以下是本人代码:

#include 
using namespace std;
typedef long long ll;
const ll inf=1000000000ll*1000000000ll;
int n,m,first[100010],tot=0;
int son[100010],fa[100010],top[100010],bot[100010],siz[100010];
int pos[100010],qpos[100010],tim=0;
ll val[100010],f[100010][2],s[100010][2];
struct edge
{
	int v,next;
}e[200010];
struct matrix
{
	ll s[2][2];
}seg[400010],Ans,E,C;

void insert(int a,int b)
{
	e[++tot].v=b;
	e[tot].next=first[a];
	first[a]=tot;
}

void dfs1(int v)
{
	f[v][0]=0,f[v][1]=val[v];
	son[v]=0;siz[v]=1;
	for(int i=first[v];i;i=e[i].next)
		if (e[i].v!=fa[v])
		{
			fa[e[i].v]=v;
			dfs1(e[i].v);
			f[v][0]+=max(f[e[i].v][0],f[e[i].v][1]);
			f[v][1]+=f[e[i].v][0];
			siz[v]+=siz[e[i].v];
			if (siz[e[i].v]>siz[son[v]])
				son[v]=e[i].v;
		}
}

void dfs2(int v,int tp)
{
	top[v]=tp;
	pos[v]=++tim,qpos[tim]=v;
	if (son[v]) dfs2(son[v],tp),bot[v]=bot[son[v]];
	else bot[v]=v;
	s[v][0]=f[v][0]-max(f[son[v]][0],f[son[v]][1]);
	s[v][1]=f[v][1]-f[son[v]][0];
	for(int i=first[v];i;i=e[i].next)
		if (e[i].v!=fa[v]&&e[i].v!=son[v])
			dfs2(e[i].v,e[i].v);
}

void Mult(matrix &S,matrix A,matrix B)
{
	for(int i=0;i<2;i++)
		for(int j=0;j<2;j++)
		{
			S.s[i][j]=-inf;
			for(int k=0;k<2;k++)
				S.s[i][j]=max(S.s[i][j],A.s[i][k]+B.s[k][j]);
		}
}

void pushup(int no)
{
	Mult(seg[no],seg[no<<1],seg[no<<1|1]);
}

void buildtree(int no,int l,int r)
{
	if (l==r)
	{
		seg[no].s[0][0]=seg[no].s[0][1]=s[qpos[l]][0];
		seg[no].s[1][0]=s[qpos[l]][1];
		seg[no].s[1][1]=-inf;
		return;
	}
	int mid=(l+r)>>1;
	buildtree(no<<1,l,mid);
	buildtree(no<<1|1,mid+1,r);
	pushup(no);
}

void modify(int no,int l,int r,int x)
{
	if (l==r)
	{
		seg[no]=C;
		return;
	}
	int mid=(l+r)>>1;
	if (x<=mid) modify(no<<1,l,mid,x);
	else modify(no<<1|1,mid+1,r,x);
	pushup(no);
}

void query(int no,int l,int r,int s,int t)
{
	if (l>=s&&r<=t)
	{
		Mult(Ans,Ans,seg[no]);
		return;
	}
	int mid=(l+r)>>1;
	if (s<=mid) query(no<<1,l,mid,s,t);
	if (t>mid) query(no<<1|1,mid+1,r,s,t);
}

void Modify(int x,ll v)
{
	ll last0,last1;
	Ans=E;
	query(1,1,n,pos[top[x]],pos[bot[x]]);
	last0=max(Ans.s[0][0],Ans.s[0][1]);
	last1=max(Ans.s[1][0],Ans.s[1][1]);
	
	s[x][1]+=v-val[x];
	C.s[1][0]=s[x][1];
	val[x]=v;
	C.s[0][0]=C.s[0][1]=s[x][0];
	C.s[1][0]=s[x][1];
	C.s[1][1]=-inf;
	
	modify(1,1,n,pos[x]);
	x=top[x];
	while(x!=1)
	{
		int y=fa[x];
		Ans=E;
		query(1,1,n,pos[x],pos[bot[x]]);
		ll ans0=max(Ans.s[0][0],Ans.s[0][1]);
		ll ans1=max(Ans.s[1][0],Ans.s[1][1]);
		s[y][0]+=max(ans0,ans1)-max(last0,last1);
		s[y][1]+=ans0-last0;
		C.s[0][0]=C.s[0][1]=s[y][0];
		C.s[1][0]=s[y][1];
		C.s[1][1]=-inf;
		
		Ans=E;
		query(1,1,n,pos[top[y]],pos[bot[y]]);
		last0=max(Ans.s[0][0],Ans.s[0][1]);
		last1=max(Ans.s[1][0],Ans.s[1][1]);
		
		modify(1,1,n,pos[y]);
		x=top[y];
	}
}

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
		scanf("%lld",&val[i]);
	for(int i=1;i<n;i++)
	{
		int a,b;
		scanf("%d%d",&a,&b);
		insert(a,b),insert(b,a);
	}
	
	fa[1]=siz[0]=0;
	dfs1(1);
	f[0][0]=f[0][1]=0;
	dfs2(1,1);
	buildtree(1,1,n);
	
	E.s[1][0]=E.s[0][1]=-inf;
	for(int i=1;i<=m;i++)
	{
		int x;ll y;
		scanf("%d%lld",&x,&y);
		Modify(x,y);
		Ans=E;
		query(1,1,n,pos[1],pos[bot[1]]);
		ll ans0=max(Ans.s[0][0],Ans.s[0][1]);
		ll ans1=max(Ans.s[1][0],Ans.s[1][1]);
		printf("%lld\n",max(ans0,ans1));
	}
	
	return 0;
}

你可能感兴趣的:(数据结构-线段树,算法-树链剖分,数学-线性代数)