很久以前就学过树分治,但是掌握不熟练(其实是弃坑了)
所以现在重新拾起这个算法,终于填坑完成……
发现还是挺简单的
树分治,是用于统计树上路径的算法
POJ1741就是一个很好的例子
下面会以此题为例,详细讲解树分治
树分治分为两种:点分治与边分治
点分治:每次找一个点,分治所有以它的儿子为根的子树
边分治:每次找一条边,分治它连接的两个点为根的子树
由于边分治容易被特殊数据卡,所以一般使用点分治
下面也是用点分治实现
前面提到点分治的效率比较稳定,是因为它每次用来分治的点是有讲究的:选择树的“重心”来分治
树的重心:若将某点设为根后,最大子树最小,那么此点就是树的重心
用反证法易得:重心的最大子树不大于整棵树的一半
也就是说,每次选择重心分治,树的大小至少减去一半
这样点分治的复杂度最坏就是 O(NlogN)
就保证了效率的稳定性
根据重心的定义,可以很容易用一次DFS找到重心
代码实现:
void get_hvy(int x,int fa){
siz[x]=1;MAX[x]=0;
for (int j=lnk[x];j;j=nxt[j])
if (son[j]!=fa&&!vis[son[j]]){
get_hvy(son[j],x);
siz[x]+=siz[son[j]];
MAX[x]=max(MAX[x],siz[son[j]]);
}
MAX[x]=max(MAX[x],S-siz[x]);
if (!hvy||MAX[x]
(hvy表示当前的重心,S表示当前树的大小,都是全局变量)
再来看树分治的核心部分。
对于当前处理的树,统计经过树根的路径
但是有不符合题意的路径(路径的两个端点在同一子树,不是简单路径)
所以再以每个儿子为根,计算一次,减去即可(注意要带上权值w)
然后对子树找重心,递归处理
代码实现:
void get_ans(int x){
vis[x]=1;ans+=get_sum(x,0);
for (int j=lnk[x];j;j=nxt[j])
if (!vis[son[j]]){
ans-=get_sum(son[j],w[j]);
hvy=0;S=siz[son[j]];get_hvy(son[j],0);
get_ans(hvy);
}
}
具体如何统计经过树根的路径呢
先获得整棵树的深度,记为dep[]
那么对于任意点对(i,j)只要满足 dep[i]+dep[j]≤k 即可被统计
那么排序后用两个指针 O(N) 推一下就好了
代码实现:
void get_dep(int x,int fa){
now[++now[0]]=dep[x];
for (int j=lnk[x];j;j=nxt[j])
if (fa!=son[j]&&!vis[son[j]]){
dep[son[j]]=dep[x]+w[j];
get_dep(son[j],x);
}
}
int get_sum(int x,int dst){
now[0]=0;int res=0;
dep[x]=dst;get_dep(x,0);
sort(now+1,now+1+now[0]);
for (int i=1,j=now[0];i<=now[0];i++){
while (j>i&&now[i]+now[j]>k) j--;
res+=max(0,j-i);
}
return res;
}
完整代码:
#include
#include
#include
using namespace std;
const int maxn=10005,maxe=2*maxn;
int n,k,hvy,S,ans;
int tot,lnk[maxn],nxt[maxe],son[maxe],w[maxe];
int siz[maxn],now[maxn],dep[maxn],MAX[maxn];
bool vis[maxn];
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int red(){
int res=0,f=1;char ch=nc();
while (ch<'0'||'9'if (ch=='-') f=-f;ch=nc();}
while ('0'<=ch&&ch<='9') res=res*10+ch-48,ch=nc();
return res*f;
}
void add(int x,int y,int z){
son[++tot]=y;nxt[tot]=lnk[x];lnk[x]=tot;w[tot]=z;
}
void get_hvy(int x,int fa){
siz[x]=1;MAX[x]=0;
for (int j=lnk[x];j;j=nxt[j])
if (son[j]!=fa&&!vis[son[j]]){
get_hvy(son[j],x);
siz[x]+=siz[son[j]];
MAX[x]=max(MAX[x],siz[son[j]]);
}
MAX[x]=max(MAX[x],S-siz[x]);
if (!hvy||MAX[x]void get_dep(int x,int fa){
now[++now[0]]=dep[x];
for (int j=lnk[x];j;j=nxt[j])
if (fa!=son[j]&&!vis[son[j]]){
dep[son[j]]=dep[x]+w[j];
get_dep(son[j],x);
}
}
int get_sum(int x,int dst){
now[0]=0;int res=0;
dep[x]=dst;get_dep(x,0);
sort(now+1,now+1+now[0]);
for (int i=1,j=now[0];i<=now[0];i++){
while (j>i&&now[i]+now[j]>k) j--;
res+=max(0,j-i);
}
return res;
}
void get_ans(int x){
vis[x]=1;ans+=get_sum(x,0);
for (int j=lnk[x];j;j=nxt[j])
if (!vis[son[j]]){
ans-=get_sum(son[j],w[j]);
hvy=0;S=siz[son[j]];get_hvy(son[j],0);
get_ans(hvy);
}
}
int main(){
for (n=red(),k=red();n||k;n=red(),k=red()){
memset(vis,0,sizeof(vis));
memset(lnk,0,sizeof(lnk));tot=0;
for (int i=1,x,y,z;i0;S=n;get_hvy(1,0);
ans=0;get_ans(hvy);
printf("%d\n",ans);
}
return 0;
}
附:点分治相关题目BZOJ2152