点分治讲解

介绍
这里的点分治用于对树上点之间的信息处理。其主要是利用分治思想以及容斥原理。
我们考虑一棵树,需要统计所有节点对的信息。直接枚举的话,时间复杂度至少是 O ( N 2 ) O(N^2) O(N2)的。而点分治可以把这个复杂度降到 O ( N l o g N ) O(NlogN) O(NlogN)
思想
我们考虑选取一个树根, O ( n ) O(n) On地求出所有到这个树根的路径的组合,其中n是这棵树的大小,再删去这个树根,把树拆成几棵子树,递归做下去。
假设每次都能选取一个理想的树根,那么对于任意一个点,考虑最坏情况每次只能拆出两棵子树,则每次子树的规模减小一半,总共需要拆解 O ( l o g N ) O(logN) OlogN次就能保证子树中仅有这个点。所以每个点最多存在于 l o g N logN logN棵子树中。那么在所有的子树中总点数不超过 N l o g N NlogN NlogN,而求出每棵子树的路径组合的时间复杂度仅与子树大小相关,那么可以保证总的时间复杂度也是 O ( N l o g N ) O(NlogN) O(NlogN)的。

以BZOJ 2152聪聪可可为例。
讲解
题面要求求一棵树上所有长为3的倍数的路径数。
那么我们对于一棵子树,可以 O ( n ) O(n) O(n)就是遍历地求出所有点到根节点的距离,那么这棵子树中满足条件的路径条数就是(距离模3余2的点数)(距离模3余1的点数)2+(距离模3余0的点数)(距离模3余0的点数)。(考虑点对可以互换位置,所以对于两个距离都不是3的倍数的点的组合需要2)。
相应的,我们可以写出这样一个函数:其中dfs()用于遍历子树同时向dis[]中记入距离。dis[]用于记录距离。

long long cal(int root,int length)
{
	cz[0]=cz[1]=cz[2]=0;
	dis[root]=length;
	cz[dis[root]%3]++;
	dfs(root,0);
	return (cz[1]*cz[2]*2)+cz[0]*(cz[0]);
}

然而,上面的对于路径求解的说明并不全对。考虑这样的情况,我们假设有这样u->v,v->k,v->j,且(|v->k|+|u->v|+|u->v|+|v->j|)%3=0,那么在返回值中就会错误地多记入路径。而实际上j到k的路径应为|v->j|+|v->k|。
这就是cal()需要第二个参数的原因。让我们来看看调用cal()的函数:

void calc(int now)
{
	ans+=cal(root,0);
	vis[root]=1;
	for(int i=h[root];i;i=nxt[i])if(!vis[p[i]])
	{
		MX=1e9;
		int v=p[i];
		ans-=cal(v,length[i]);
		calc_size(v,now);
		size=siz[v];
		find_root(v,now);
		calc(v);
	}
}

这一段是分治部分的函数。now是预先计算出的当前子树的树根,find_root()用于寻找树根。其中cal()用于计算某棵子树的所有路径条数。vis数组用于标记访问过的点,同时也在cal()中用于判断子树的分割边界。calc_size()函数用于重新计算某个点对应的子树大小。
注意到ans+=cal(root,0); 和对于下面遍历每个子节点时的ans-=cal(v,length[i]);
其中ans+=cal(root,0);是用于计算当前子树的所有可能路径,而ans-=cal(v,length[i]);是用于去除当前树根下的所有重复路径。
我们需要想办法减去错误合并的路径。这里解决办法就是ans-=cal(v,length[i]);相当于在v对应的子树中重新找到这条路径,然后减去它。
显然,在ans+=cal(root,0);ans-=cal(v,length[i]);中dis[v]的值都相同,进而使得属于v的子树的点的dis值都相同,所以cal(v,length[i])返回的就是所有在v的子树中被错误合并的路径条数。

下面附上这题的完整代码:

#include
using namespace std;
typedef long long ll;
const int maxn=1e6+5;
int p[maxn],h[maxn],nxt[maxn],length[maxn],dis[maxn],cz[5];
int siz[maxn],vis[maxn],q[maxn];
int root,n,size,MX; 
long long ans;
int tot=1;
inline void addedge(int a,int b,int c)
{
	p[tot]=b;nxt[tot]=h[a];length[tot]=c;h[a]=tot++;
}
void find_root(int now,int f)
{
	int mi=0;
	mi=size-siz[now];
	for(int i=h[now];i;i=nxt[i])
	if(!vis[p[i]]&&p[i]!=f)
	{	
		find_root(p[i],now);
		mi=max(mi,siz[p[i]]);
	}
	if(MX>mi) 
	{
		MX=mi;root=now;
	} 
	return ;
}
void calc_size(int now,int f)
{
	siz[now]=1;
	for(int i=h[now];i;i=nxt[i])
	{
		if(p[i]!=f&&!vis[p[i]])
		{
			calc_size(p[i],now);
			siz[now]+=siz[p[i]];
		}
	}
	return ;
}
void dfs(int now,int fa)
{
	for(int i=h[now];i;i=nxt[i])
	{
		if(!vis[p[i]]&&p[i]!=fa)
		{
			dis[p[i]]=dis[now]+length[i];
			cz[dis[p[i]]%3]++;
			dfs(p[i],now);
		}
	}
	return;
}
long long cal(int root,int length)
{
	cz[0]=cz[1]=cz[2]=0;
	dis[root]=length;
	cz[dis[root]%3]++;
	dfs(root,0);
	return (cz[1]*cz[2]*2)+cz[0]*(cz[0]);
}
void calc(int now)
{
	ans+=cal(root,0);
	vis[root]=1;
	for(int i=h[root];i;i=nxt[i])if(!vis[p[i]])
	{
		MX=1e9;
		int v=p[i];
		ans-=cal(v,length[i]);
		calc_size(v,now);
		size=siz[v];
		find_root(v,now);
		calc(v);
	}
}
long long gcd(long long a,long long b){return (!a)?b:gcd(b%a,a);}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<n;i++)
	{
		int x,y,w;
		scanf("%d%d%d",&x,&y,&w);
		addedge(x,y,w);addedge(y,x,w);
	}
	MX=1e9;
	calc_size(1,0);
	size=siz[1];
	find_root(1,0);
	calc(1);
	long long pz=gcd(ans,(long long)n*n);
	printf("%lld/%lld",ans/pz,(long long)n*n/pz);
	return 0;
}

你可能感兴趣的:(点分治)