POJ 1741 Tree 树分治 dp

C - Tree
Time Limit:1000MS     Memory Limit:30000KB     64bit IO Format:%I64d & %I64u
Submit  Status  Practice  POJ 1741
Appoint description:  System Crawler  (2016-04-20)

Description

Given a tree with n vertices, each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many pairs are valid for the given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

问题描述:

POJ 1741 Tree 树分治 dp_第1张图片
推荐qzc的论文http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html
一开始用lca和dfs超时了
看了解题报告才知道要用树分治(点分治)。分治的中心思想是:找到一个点(树的重心),使删去该点后得到的最大子树尽可能小;这样每个剩余部分的大小都不超过原来的一半,递归深度只有 O(log n),问题就能像二分一样被逐层分解。
还是有点晕
看对拍的代码:
#pragma warning(disable:4786)//使命名长度不受限制
#pragma comment(linker, "/STACK:102400000,102400000")//手工开栈
#include <map>
#include <set>
#include <queue>
#include <cmath>
#include <stack>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <stdlib.h>
#include <iostream>
#include <algorithm>
// Input shorthands: each expands to a scanf call and yields scanf's return
// value (number of fields converted, or EOF).
#define rd(x) scanf("%d",&x)
#define rd2(x,y) scanf("%d%d",&x,&y)
#define rd3(x,y,z) scanf("%d%d%d",&x,&y,&z)
// %I64d matches the judge's 64-bit I/O format stated in the problem header;
// use %lld on POSIX toolchains.
#define rdl(x) scanf("%I64d",&x)
#define rds(x) scanf("%s",x)
#define rdc(x) scanf("%c",&x)
#define ll long long int
#define ull unsigned long long
#define maxn 100500             // capacity of every per-vertex / per-edge array
#define mod 1000000007
#define INF 0x3f3f3f3f          // "infinity" sentinel (1061109567); INF+INF still fits in int
#define FOR(i,f_start,f_end) for(int i=f_start;i<=f_end;++i)   // inclusive range
#define MT(x,i) memset(x,i,sizeof(x))
#define PI  acos(-1.0)
#define E  exp(1)
#define eps 1e-8

/// Greatest common divisor by Euclid's algorithm; gcd(a,0) == a.
ll gcd(ll a,ll b){return b==0?a:gcd(b,a%b);}

/// Returns (a*b) mod p using shift-and-add multiplication, so the full product
/// never has to fit in 64 bits.
/// Preconditions: b >= 0 (with a negative b, b>>=1 typically never reaches 0),
/// and 0 < p <= 2^62 so that doubling a reduced value cannot overflow.
ll mul(ll a,ll b,ll p){
    ll sum=0;
    a%=p;   // reduce first: doubling an unreduced a near LLONG_MAX would overflow
    for(;b;a=(a+a)%p,b>>=1)
        if(b&1)sum=(sum+a)%p;
    return sum;
}

/// Fast reader for a non-negative decimal integer from stdin.
/// Skips every non-digit character, then accumulates digits; the character
/// that terminates the number is consumed. If input ends before any digit is
/// found, x is set to 0 instead of spinning forever on EOF.
inline void Scan(int &x) {
    int c=getchar();    // int, not char, so EOF is distinguishable from data
    while(c!=EOF&&(c<'0'||c>'9'))c=getchar();
    x=0;
    while(c>='0'&&c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
}
using namespace std;
/// One directed half of an undirected edge, stored in a linked adjacency list.
struct N{
    int to;     // vertex this half-edge points at
    int next;   // index of the next half-edge leaving the same vertex, or -1
    int len;    // edge length
};
N my[maxn];     // edge pool, filled by add()

int head[maxn]; // head[u] = index of the first half-edge leaving u, or -1 (see init)
int tot;        // number of half-edges stored in my[] so far
int dep[maxn];  // scratch buffer of distances from the current centroid
int Size[maxn]; // subtree sizes written by get_size
bool vis[maxn]; // vertices currently cut out of the tree (active centroids)
int n,k;        // vertex count and distance bound of the current test case
int le,ri;      // dep[le..ri) is the segment currently being filled
int minn;       // smallest "largest remaining part" found so far by get_root
void add(int u,int v,int w){
    my[tot].to=v;my[tot].len=w;
    my[tot].next=head[u];head[u]=tot++;
}
/// Empties the graph: the edge pool is reset and every adjacency list head
/// becomes -1 (no edges).
void init(){
    tot=0;
    memset(head,-1,sizeof head);
}
/// Size of the subtree rooted at u when entered from fa, ignoring vertices
/// marked in vis[]. Records the result in Size[u] and returns it.
int get_size(int u,int fa){
    int total=1;   // u itself
    for(int e=head[u];e!=-1;e=my[e].next){
        const int to=my[e].to;
        if(to!=fa&&!vis[to])
            total+=get_size(to,u);
    }
    Size[u]=total;
    return total;
}
/// Post-order search for the centroid of the component (of num vertices) that
/// contains u. A vertex's score is the size of the largest part that would
/// remain if it were removed: a child subtree, or the num-Size[u] vertices on
/// the parent side. Whenever a score beats the global minn, minn and root are
/// updated. Size[] must already hold values from get_size on this component,
/// and the caller must reset minn beforehand.
void get_root(int u,int fa,int num,int &root){
    int largest=num-Size[u];   // part hanging above u
    for(int e=head[u];e!=-1;e=my[e].next){
        const int to=my[e].to;
        if(to==fa||vis[to])continue;
        get_root(to,u,num,root);
        if(Size[to]>largest)largest=Size[to];
    }
    if(largest<minn){
        minn=largest;
        root=u;
    }
}
/// Appends d (u's distance from the current centroid) at dep[ri], advancing ri,
/// then visits u's unvisited neighbours in pre-order, extending the distance by
/// each edge length.
void find_depth(int u,int fa,int d){
    dep[ri]=d;
    ++ri;
    for(int e=head[u];e!=-1;e=my[e].next){
        const int to=my[e].to;
        if(to!=fa&&!vis[to])
            find_depth(to,u,d+my[e].len);
    }
}
/// Sorts dep[a..b) ascending in place and counts the pairs of entries in that
/// range whose sum does not exceed k.
/// For each i (stopping at the first dep[i] > k), `last` is walked down to the
/// last index still compatible with dep[i]. The loop accumulates ordered pairs.
/// The self-pair is dropped only when last > i. When last == i it stays in,
/// but that can happen for at most one i (last never increases while i does),
/// so the final halving still gives the pair count.
int get_depth(int a,int b){
    sort(dep+a,dep+b);
    int ordered=0;
    int last=b-1;
    for(int i=a;i<b&&dep[i]<=k;++i){
        while(last>=a&&dep[last]+dep[i]>k)--last;
        ordered+=last-a+1;
        if(last>i)--ordered;
    }
    return ordered/2;
}
/// Counts vertex pairs at distance <= k inside the component containing u,
/// where the component is bounded by vertices currently marked in vis[].
/// Centroid decomposition: find the component's centroid, solve each piece
/// left after removing it recursively, then add the pairs whose path passes
/// through the centroid. vis[] is back to its entry state when this returns.
/// Note: recursive calls overwrite the globals minn, le, ri and dep[], which is
/// why le/ri are reset only after the recursion below.
int dfs(int u){
    int ret=0;
    minn=INF;                    // reset get_root's global best
    int num=get_size(u,-1);      // component size; also fills Size[]
    int root;
    get_root(u,-1,num,root);     // root := centroid of this component
    vis[root]=true;              // cut the centroid out
    // Pairs lying entirely inside one remaining piece: solved recursively.
    for(int i=head[root];i+1;i=my[i].next){
        int v=my[i].to;
        if(vis[v])continue;
        ret+=dfs(v);
    }
    // Each recursive call un-marked its own centroid, so every piece is whole
    // again. Gather distances from root piece by piece into dep[le..ri) and
    // subtract pairs from the same piece: their path does not pass through root.
    le=0,ri=0;
    for(int i=head[root];i+1;i=my[i].next){
        int v=my[i].to;
        if(vis[v])continue;
        find_depth(v,root,my[i].len);
        ret-=get_depth(le,ri);
        le=ri;
    }
    // Add every pair among all gathered vertices with summed distance <= k.
    ret+=get_depth(0,ri);
    // Add pairs (root, x): dep[0..ri) is sorted now, so stop at the first > k.
    for(int i=0;i<ri;++i)
        if(dep[i]<=k)ret++;
        else break;
    vis[root]=false;             // restore so the caller sees its full component
    return ret;
}
int u,v,len;
/// Processes test cases until the "0 0" terminator (or end of input). For each
/// tree, prints the number of vertex pairs whose distance is at most k.
int main(){
    // scanf returns EOF (-1) at end of input, which is truthy; compare against
    // the number of fields read so a missing "0 0" line cannot loop forever.
    while(rd2(n,k)==2&&(n||k)){
        init();
        FOR(i,2,n){
            // Truncated input: stop instead of working on garbage values.
            if(rd2(u,v)!=2||rd(len)!=1)return 0;
            // store the undirected edge as two directed half-edges
            add(u,v,len);
            add(v,u,len);
        }
        memset(vis,false,sizeof(vis));
        printf("%d\n",dfs(1));
    }
    return 0;
}



你可能感兴趣的:(C++,dp)