树形dp之节点对 poj1741

题目很经典,完全看了别人的解析啊,树的分治,算法复杂度o(nlogn^2)

#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<queue>
#include<stack>
#include<vector>
#include<climits>
#include<map>
using namespace std;

#define rep(i,n) for(int i=0; i<(n); ++i)
#define repf(i,n,m) for(int i=(n); i<=(m); ++i)
#define repd(i,n,m) for(int i=(n); i>=(m); --i)
#define ll long long
#define inf 1000000000
#define exp 0.000000001
#define N 10010 
struct node
{
       int y,pre,len;
};
node a[N*2];
int size[N],next[N],pre[N];
int n,m,len,ans,root,Min;
bool vis[N];
int dis[N];

void init()
{
     len=1; ans=0;Min=inf;
     memset(pre,-1,sizeof(pre));
     memset(vis,false,sizeof(vis));
}
void addpage(int x,int y,int w)
{
     a[len].y=y;
     a[len].len=w;
     a[len].pre=pre[x];
     pre[x]=len++;
 }
 
void dfssize(int u,int fa)
{
     size[u]=1; next[u]=0;
     for(int i=pre[u]; i!=-1; i=a[i].pre)
     {
          int y=a[i].y;
          if(y==fa || vis[y]) continue;
          dfssize(y,u);
          size[u]+=size[y];
          next[u]=max(next[u],size[y]);
     }
 }
 void dfsroot(int u,int v,int fa)
 {
      next[v]=max(size[u]-size[v],next[v]);
      if(next[v]<Min)  Min=next[v],root=v;
      for(int i=pre[v]; i!=-1; i=a[i].pre)
      {
              int y=a[i].y;
              if(y==fa || vis[y]) continue;
              dfsroot(u,y,v);
      }
  }
  
void dfs(int u,int d,int fa)
{
     dis[len++]=d;
     for(int i=pre[u]; i!=-1; i=a[i].pre)
     {
             int y=a[i].y;
             if(y==fa || vis[y]) continue;
             dfs(y,d+a[i].len,u);
     }
}
  int calc(int u,int d)
  {
      len=0;
      dfs(u,d,-1);
      sort(dis,dis+len);
      int sum=0;
      int i=0,j=len-1;
      while(i<j)
      {
                while(dis[i]+dis[j]>m && i<j) --j;
                sum+=j-i;
                i++;
      }
      return sum;
  }
void dfs(int u)
{
     Min=inf;
    dfssize(u,-1);     
    dfsroot(u,u,-1);//寻找重心的 
    ans+=calc(root,0);
    vis[root]=true;
    int ro=root;
    for(int i=pre[ro]; i!=-1; i=a[i].pre)
    {
            int y=a[i].y;
            if(vis[y]) continue;
            ans-=calc(y,a[i].len);
            dfs(y);
    }
}

int main()
{
    while(scanf("%d%d",&n,&m))
    {
          if(n==0 && m==0) break;
          init();
          int x,y,z;
          repf(i,1,n-1)
          {
             scanf("%d%d%d",&x,&y,&z);
             addpage(x,y,z);
             addpage(y,x,z);
          }
          dfs(1);
          printf("%d\n",ans);
    }
    return 0;
}


你可能感兴趣的:(树形dp之节点对 poj1741)