[BZOJ 2152]聪聪可可(点分治)

Description
聪聪和可可是兄弟俩,他们俩经常为了一些琐事打起来,例如家中只剩下最后一根冰棍而两人都想吃、两个人都想玩儿电脑(可是他们家只有一台电脑)……遇到这种问题,一般情况下石头剪刀布就好了,可是他们已经玩儿腻了这种低智商的游戏。他们的爸爸快被他们的争吵烦死了,所以他发明了一个新游戏:由爸爸在纸上画n个“点”,并用n-1条“边”把这n个“点”恰好连通(其实这就是一棵树)。并且每条“边”上都有一个数。接下来由聪聪和可可分别随即选一个点(当然他们选点时是看不到这棵树的),如果两个点之间所有边上数的和加起来恰好是3的倍数,则判聪聪赢,否则可可赢。聪聪非常爱思考问题,在每次游戏后都会仔细研究这棵树,希望知道对于这张图自己的获胜概率是多少。现请你帮忙求出这个值以验证聪聪的答案是否正确。

Input
输入的第1行包含1个正整数n。后面n-1行,每行3个整数x、y、w,表示x号点和y号点之间有一条边,上面的数是w。

Output
以即约分数形式输出这个概率(即“a/b”的形式,其中a和b必须互质。如果概率为1,输出“1/1”)。

Sample Input

5
1 2 1
1 3 2
1 4 1
2 5 3

Sample Output

13/25

说明
13组点对分别是(1,1) (2,2) (2,3) (2,5) (3,2) (3,3) (3,4) (3,5) (4,3) (4,4) (5,2) (5,3) (5,5)。

【数据规模】
对于100%的数据,n<=20000。

思路
显然所有点对数为 n ∗ n n*n nn,我们只需要求长度和为3的倍数的路径数就好了。所以上点分治

代码
这一段是分治部分的函数。其中cal()用于计算某棵子树的所有路径条数。vis数组用于标记访问过的点,同时也在cal()中用于判断子树的分割边界。
注意到ans+=cal(root,0); 和对于下面遍历每个子节点时的ans-=cal(v,length[i]);
其中ans+=cal(root,0);是用于计算当前子树的所有可能路径,而ans-=cal(v,length[i]);是用于去除当前树根下的所有重复路径。(详细内容与cal()函数的实现有关)
calc_size()函数用于重新计算某个点对应的子树大小。先前看到有板子似乎省掉了这一步骤,但后来发现他只是把这一步骤合并到其它函数里了

void calc(int now)
{
	ans+=cal(root,0);
	vis[root]=1;
	for(int i=h[root];i;i=nxt[i])if(!vis[p[i]])
	{
		MX=1e9;
		int v=p[i];
		ans-=cal(v,length[i]);
		calc_size(v,now);
		size=siz[v];
		find_root(v,now);
		calc(v);
	}
}

下面是cal()函数
其中dis()用于记录子树中所有点到当前root的距离,length为初始距离。最后的返回值就是子树中任意两个点距离和为3的倍数的数量。
dfs函数能够 O ( n ) O(n) On的遍历子树,其中n为子树大小。
考虑到先前的ans+=cal(root,0);,我们假设有这样u->v,v->k,v->j,且(|v->k|+|u->v|+|u->v|+|v->j|)%3=0,那么在cal(root,0)中就会错误地将ans+1。而实际上j到k的路径应为|v->j|+|v->k|。所以我们需要想办法减去这条错误合并的路径。这里解决办法就是ans-=cal(v,length[i]);相当于在v对应的子树中重新找到这条路径,然后减去它。
显然,在ans+=cal(root,0);ans-=cal(v,length[i]);中dis[v]的值都相同,进而使得属于v的子树的点的dis值都相同,所以cal(v,length[i])返回的就是所有在v的子树中被错误合并的路径条数。

long long cal(int root,int length)
{
	cz[0]=cz[1]=cz[2]=0;
	dis[root]=length;
	cz[dis[root]%3]++;
	dfs(root,0);
	return (cz[1]*cz[2]*2)+cz[0]*(cz[0]);
}

AC代码

#include
using namespace std;
typedef long long ll;
const int maxn=1e6+5;
int p[maxn],h[maxn],nxt[maxn],length[maxn],dis[maxn],cz[5];
int siz[maxn],vis[maxn],q[maxn];
int root,n,size,MX; 
long long ans;
int tot=1;
inline void addedge(int a,int b,int c)
{
	p[tot]=b;nxt[tot]=h[a];length[tot]=c;h[a]=tot++;
}
void find_root(int now,int f)
{
	int mi=0;
	mi=size-siz[now];
	for(int i=h[now];i;i=nxt[i])
	if(!vis[p[i]]&&p[i]!=f)
	{	
		find_root(p[i],now);
		mi=max(mi,siz[p[i]]);
	}
	if(MX>mi) 
	{
		MX=mi;root=now;
	} 
	return ;
}
void calc_size(int now,int f)
{
	siz[now]=1;
	for(int i=h[now];i;i=nxt[i])
	{
		if(p[i]!=f&&!vis[p[i]])
		{
			calc_size(p[i],now);
			siz[now]+=siz[p[i]];
		}
	}
	return ;
}
void dfs(int now,int fa)
{
	for(int i=h[now];i;i=nxt[i])
	{
		if(!vis[p[i]]&&p[i]!=fa)
		{
			dis[p[i]]=dis[now]+length[i];
			cz[dis[p[i]]%3]++;
			dfs(p[i],now);
		}
	}
	return;
}
long long cal(int root,int length)
{
	cz[0]=cz[1]=cz[2]=0;
	dis[root]=length;
	cz[dis[root]%3]++;
	dfs(root,0);
	return (cz[1]*cz[2]*2)+cz[0]*(cz[0]);
}
void calc(int now)
{
	ans+=cal(root,0);
	vis[root]=1;
	for(int i=h[root];i;i=nxt[i])if(!vis[p[i]])
	{
		MX=1e9;
		int v=p[i];
		ans-=cal(v,length[i]);
		calc_size(v,now);
		size=siz[v];
		find_root(v,now);
		calc(v);
	}
}
long long gcd(long long a,long long b){return (!a)?b:gcd(b%a,a);}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<n;i++)
	{
		int x,y,w;
		scanf("%d%d%d",&x,&y,&w);
		addedge(x,y,w);addedge(y,x,w);
	}
	MX=1e9;
	calc_size(1,0);
	size=siz[1];
	find_root(1,0);
	calc(1);
	long long pz=gcd(ans,(long long)n*n);
	printf("%lld/%lld",ans/pz,(long long)n*n/pz);
	return 0;
}

你可能感兴趣的:(点分治)