树分治的第一题,参考了论文及别人的博客
首先两点之间距离小于等于K,一定只有三种情况(假设设定了一个根)
1.两点都在根的某一棵子树
2.两点在不同的子树
3.一个点为根,另一个点为树中的一个点
1可有2,3解得
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <ctime>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define INF 0x3f3f3f3f
#define inf -0x3f3f3f3f
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mem0(a) memset(a,0,sizeof(a))
#define mem1(a) memset(a,-1,sizeof(a))
#define mem(a, b) memset(a, b, sizeof(a))
typedef long long ll;
#define N 10010
typedef pair<int,int> PI;
vector<PI>G[N];
int siz[N];
int f[N],can[N],d[N];
int list[N];
int ans,K,t1,l1,l2;
void init(int n){
for(int i=1;i<=n;i++)
G[i].clear();
mem1(can);
ans=0;
}
void dfs1(int x,int fa){
siz[x]=1;
list[++t1]=x;
for(int i=0;i<G[x].size();i++){
int v=G[x][i].first;
if(can[v]&&v!=fa){
dfs1(v,x);
f[v]=x;
siz[x]+=siz[v];
}
}
}
int get_root(int x,int fa){
t1=0;
dfs1(x,fa);
int pos,tmp=INF;
for(int i=1;i<=t1;i++){
int d1=0;
int y=list[i];
for(int j=0;j<G[y].size();j++){
int v=G[y][j].first;
if(v!=f[y]&&can[v])
d1=max(siz[v],d1);
}
if(y!=x)
d1=max(d1,siz[x]-siz[y]);
if(d1<tmp){
pos=y;
tmp=d1;
}
}
return pos;
}
void dfs2(int x,int fa,int dis){
list[++l1]=x;
d[x]=dis;
for(int i=0;i<G[x].size();i++){
int v=G[x][i].first;
if(can[v]&&v!=fa)
dfs2(v,x,dis+G[x][i].second);
}
}
inline int cmp(int i,int j){
return d[i] < d[j];
}
int getans(int *a,int l,int r){
int j=r;
int ret=0;
for(int i=l;i<=r;i++){
while(d[a[i]]+d[a[j]]>K&&j>i)
j--;
ret+=(j-i);
if(j==i)
break;
}
return ret;
}
void work(int x,int fa){
int root=get_root(x,fa);
l1=l2=0;
for(int i=0;i<G[root].size();i++){
int v=G[root][i].first;
if(can[v]){
l2=l1;
dfs2(v,root,G[root][i].second);
sort(list+l2+1,list+l1+1,cmp);
ans-=getans(list,l2+1,l1);
}
}
list[++l1]=root;
d[root]=0;
sort(list+1,list+1+l1,cmp);
ans+=getans(list,1,l1);
can[root]=0;
for(int i=0;i<G[root].size();i++){
int v=G[root][i].first;
if(can[v])
work(v,root);
}
}
int main(){
int n,m;
while(scanf("%d%d",&n,&K)!=EOF){
if(n==0&&K==0)
break;
init(n);
int u,v,w;
for(int i=1;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
G[u].push_back(PI(v,w));
G[v].push_back(PI(u,w));
}
work(1,0);
printf("%d\n",ans);
}
return 0;
}