CSP-S 2019 D2T3 树的重心 题解

有参考洛谷题解里面 ren482933891 所作的题解。

题意简述

给定一棵树,枚举删每条边,对于每种删边方案求出两棵新树的重心,求所有删边方案的重心编号和之和。

题解 O ( n log ⁡ n ) O(n\log n) O(nlogn)

(重儿子,重链定义同树剖)

考虑基本定理:

设一棵树以 x x x 为根,那么它的重心要么是 x x x,要么在以 x x x 的重儿子为根的子树中。

那么容易有推广:

设一棵树以 x x x 为根,那么它的重心要么是 x x x,要么在 x x x 连出的重链上。

我们发现,如果沿着一个点 x x x 沿重链往下找,那么遇到的第一个 p p p 使得 siz p ≤ siz x 2 \text{siz}_p \leq \frac{\text{siz}_x}{2} sizp2sizx 就是重心。

另外地,如果 siz p = siz x 2 \text{siz}_p = \frac{\text{siz}_x}{2} sizp=2sizx,那么 fa p \text{fa}_p fap 也是一个可能的选择,需要加以判断。

接下来我们需要知道怎么快速找到这个 p p p。很明显,重链上倍增向下跳即可。

然后我们回到要求的问题上来。

设割断了 ( u , v ) (u,v) (u,v),且不失一般性设 v v v 深度更大。

那么我们就要求 原树 - subtree v \text{subtree}_v subtreev 的重心和 subtree v \text{subtree}_v subtreev 的重心。

后者刚刚已经说过可以倍增求出。对于前者,我们发现可能重链的剖分出现了些变化。

我们从 u u u 倍增向上跳到第一个由于它的影响使得重儿子改变的点 t t t,修改一下重链走向。

然后,分两段跳,第一段为根到 t t t 的重链,第二段为 t t t 起头的变向重链,倍增向下跳寻找重心即可。

为了查询是否影响及修改重链走向,我们需要存子树次大的儿子,以及相关联的倍增信息。

时间复杂度 O ( n log ⁡ n ) O(n\log n) O(nlogn)

一个 O ( n ) O(n) O(n) 的做法

我们先一遍 dp 求出原树重心,然后以重心为根。

对于每次割边 ( u , v ) (u,v) (u,v),我们将 subtree v \text{subtree}_v subtreev 的重心定为内重心,原树 - subtree v \text{subtree}_v subtreev 的重心定为外重心

先处理内重心

我们用 DFS 的方式,按栈序枚举删边。对于删 ( u , v ) (u,v) (u,v) 的情况,我们一定已经计算出了 v v v 的所有子树的重心。

那么,我们依次考虑 v v v 的每棵子树,将它并进来计算新的重心。发现合并一棵子树等同于用一条边把两棵树相连(一棵是 v v v 及之前所有子树的并,一棵是当前子树)。

对于这种情况,我们选取较大的一棵树的重心。然后可以发现,合并后的重心一定是这个重心向上跳出来的。我们将这个重心沿重链暴力上跳,判断是否满足条件即可。

由于对于每个点,最多被上跳到它的儿子数 +1 次,所以复杂度 O ( n ) O(n) O(n)

然后处理外重心

在处理之前,我们用 DFS 的方式考虑删边,但是我们先按顺序不是栈序,而是队列序来考虑。

如果我们计算出了删 ( u , v ) (u,v) (u,v) 所得的外重心,然后去计算删掉 ( v , son v ) (v,\text{son}_v) (v,sonv) 所得外重心,容易发现外重心只会上跳。

同时,又由于根是原树重心,所以跳到根以后必然不会下跳。

因此,我们只需要处理出对于割断每条 ( r t , son r t ) (rt,\text{son}_{rt}) (rt,sonrt) 的情况下的外重心,就能对每种情况都找到一条重心路径,使得在处理 son r t \text{son}_{rt} sonrt 为根的子树时,重心必然在此路径上移动。处理时实际上只需要求出根节点最大和次大的子树即可,

然后,我们再次考虑用 DFS 栈序来枚举删边。对于删 ( u , v ) (u,v) (u,v) 的情况,我们一定已经计算出了割断 v v v 及其儿子的所有情况的外重心

删除 ( u , v ) (u,v) (u,v) 所得外重心,必然在重心路径上,且比割断 v v v 及其儿子的所有情况的外重心更深。

我们从外重心最深的一个情况转移,然后暴力在重心路径上下跳,判断是否满足条件即可。时间复杂度 O ( n ) O(n) O(n)

代码: O ( n ) O(n) O(n)

#include 
#include 
#include 
using namespace std;
const int N=3e5+1;
char buf[1<<20];
int ppp,qqq;
inline char gc()
{
	if(ppp==qqq) ppp=0,qqq=fread(buf,1,1<<20,stdin);
	return buf[ppp++];
}
inline int rd()
{
	int t=0;char x;
	while((x=gc())<'0'||x>'9');t=x-'0';
	while((x=gc())>='0'&&x<='9') t=t*10+x-'0';
	return t; 
}
int bg[N],nx[N*2],to[N*2],tl;
inline void add(int x,int y)
{
	nx[++tl]=bg[x];
	bg[x]=tl;
	to[tl]=y;
}
long long ans;
int n,ctr,siz[N];
void dfs1(int now,int f)
{
	siz[now]=1;
	int is=1;
	for(int i=bg[now];i;i=nx[i])
	{
		int aim=to[i];
		if(aim==f) continue;
		dfs1(aim,now);
		if(siz[aim]>n/2) is=0;
		siz[now]+=siz[aim];
	}
	if(n-siz[now]>n/2) is=0;
	if(is) ctr=now;
}
int mx[N],sec,dep[N],fa[N],inp[N];
void dfs2(int now,int f,int depth)
{
	siz[now]=1,dep[now]=depth;
	fa[now]=f,inp[now]=now;
	for(int i=bg[now];i;i=nx[i])
	{
		int aim=to[i];
		if(aim==f) continue;
		dfs2(aim,now,depth+1);
		if(siz[now]<=siz[aim]) inp[now]=inp[aim];
		siz[now]+=siz[aim];
		while(siz[now]-siz[inp[now]]>siz[now]/2) inp[now]=fa[inp[now]];
		if(!f)
		{
			if(siz[aim]>siz[mx[now]]) sec=mx[now],mx[now]=aim;
			else if(siz[aim]>siz[sec]) sec=aim;
		}
		else if(siz[aim]>siz[mx[now]]) mx[now]=aim;
	}
	if(now==ctr) return;
	ans+=inp[now];
	if(siz[inp[now]]==siz[now]/2) ans+=fa[inp[now]];
}
int outp[N];
void dfs3(int now,int typ)
{
	outp[now]=ctr;
	for(int i=bg[now];i;i=nx[i])
	{
		int aim=to[i];
		if(aim==fa[now]) continue;
		dfs3(aim,typ);
		if(dep[outp[aim]]>dep[outp[now]]) outp[now]=outp[aim];
	}
	if(outp[now]==ctr)
	{
		if(typ&&siz[sec]>=(n-siz[now])/2&&n-siz[sec]-siz[now]<=(n-siz[now])/2) outp[now]=sec;
		else if(!typ&&siz[mx[ctr]]>=(n-siz[now])/2&&n-siz[mx[ctr]]-siz[now]<=(n-siz[now])/2) outp[now]=mx[ctr];
	}
	if(outp[now]!=ctr)
	{
		while(siz[mx[outp[now]]]>=(n-siz[now])/2&&n-siz[now]-siz[mx[outp[now]]]<=(n-siz[now])/2) outp[now]=mx[outp[now]];
	}
	ans+=outp[now];
	if(siz[outp[now]]==(n-siz[now])/2) ans+=fa[outp[now]];
}
void reset()
{
	tl=ctr=sec=ans=0;
	memset(bg,0,sizeof(bg));
	memset(mx,0,sizeof(mx));
	memset(inp,0,sizeof(inp));
	memset(outp,0,sizeof(outp));
}
int main()
{
	int T=rd();
	while(T--)
	{
		reset();
		int i;
		n=rd();
		for(i=1;i<n;i++)
		{
			int x=rd(),y=rd();
			add(x,y),add(y,x);
		}
		dfs1(1,0);
		dfs2(ctr,0,1);
		for(i=bg[ctr];i;i=nx[i])
		{
			int aim=to[i];
			if(aim==mx[ctr]) dfs3(aim,1);
			else dfs3(aim,0);
		}
		printf("%lld\n",ans);
	}
	return 0;
}

你可能感兴趣的:(OI)