点分治就是在一棵树中,将每个点分治……
基本概念:
点分治:将一棵无根树变成有根树,再分别处理每棵有根子树。
重心:在一棵树中,这个点的最大子树是所有点中最小的。也可以说是删除该点时,树内剩下的子树最大节点数最小。
size[i]表示以i为根的子树节点数量。
如何求重心??求出size,什么是定义,就怎么求。一般来说总(不是每次)时间复杂度为O(n)
找重心的代码(这里的代码都对应着下面的例题):
void findroot(int x,int fa)
{
size[x]=1;mx[x]=0;
for(int i=last[x];i;i=next[i])
{
if (to[i]==fa || bz[to[i]]) continue;
findroot(to[i],x);
size[x]+=size[to[i]];
mx[x]=max(mx[x],size[to[i]]);
}
mx[x]=max(mx[x],nn-size[x]);
if (mx[x]<mx[root]) root=x;
}
int main()
{
mx[0]=2147483647;
nn=n;//一开始的子树就是整棵树,大小为n
root=0;findroot(1,0);
}
代码中的nn即为这个子树的节点数。
为什么要找重心??
为了使时间复杂度少,每棵子树都要尽量小,那么根一定是重心。
给出一棵带边权的树,问有多少对点的距离<=Len
1、对于每个点,找出有多少对经过这个点的点对距离在Len内,总(不是每次)O(n),减去在子节点中算重的。
2、删掉这个点,对于剩下的一堆子树中再重复做。
为了保证时间复杂度,每次都要找重心。
代码:
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define ll long long
#define N 11000
#define clear(a) memset(a,0,sizeof(a));
using namespace std;
int n,nn,len,m,size[N],mx[N],last[N*10],next[N*10],to[N*10],tot,root,ans;
ll data[N*10];
bool bz[N];
ll deep[N],dep[N];
void putin(int x,int y,int z)
{
next[++tot]=last[x];last[x]=tot;data[tot]=z;to[tot]=y;
}
void findroot(int x,int fa)
{
size[x]=1;mx[x]=0;
for(int i=last[x];i;i=next[i])
{
if (to[i]==fa || bz[to[i]]) continue;
findroot(to[i],x);
size[x]+=size[to[i]];
mx[x]=max(mx[x],size[to[i]]);
}
mx[x]=max(mx[x],nn-size[x]);
if (mx[x]<mx[root]) root=x;
}
void getdeep(int x,int fa)
{
for(int i=last[x];i;i=next[i])
{
if (to[i]==fa || bz[to[i]]) continue;
dep[to[i]]=dep[x]+data[i];
getdeep(to[i],x);
}
deep[++tot]=dep[x];
}
int calc(int x)
{
int ans=0;
sort(deep+1,deep+tot+1);
int j=tot;int i=1;
while (deep[i]==0) i++;
for(;i<j;)
{
if (deep[i]+deep[j]-2<=len) {ans+=j-i;i++;}
else j--;
}
return ans;
}
void dg(int x,int fa)
{
tot=0;clear(dep);dep[x]=1;getdeep(x,fa);
ans+=calc(x);
bz[x]=1;
for(int i=last[x];i;i=next[i])
{
int k=to[i];
if (bz[k]==1) continue;
tot=0;clear(dep);dep[k]=data[i]+1;getdeep(k,x);
ans-=calc(k);
nn=size[k];
root=0;findroot(k,x);
dg(root,x);
}
}
int main()
{
scanf("%d%d",&n,&len);
fo(i,1,n-1)
{
int x,y,z;scanf("%d%d%d",&x,&y,&z);
putin(x,y,z);putin(y,x,z);
}
mx[0]=2147483647;nn=n;root=0;findroot(1,0);
dg(root,0);
printf("%d",ans);
}