poj1741 Tree(树形dp)

Tree
Time Limit: 1000MS   Memory Limit: 30000K
Total Submissions: 14706   Accepted: 4781

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001). 
Define dist(u,v) as the minimum distance between nodes u and v. 
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k. 
Write a program that counts how many valid pairs there are in the given tree. 

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning that there is an edge of length l between nodes u and v. 
The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Source

LouTiancheng@POJ

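题意:给定一棵n个点的带边权的树,求距离不超过k的点对(u,v)的个数。
分析:采用点分治。每次取当前连通块的重心作为根,统计经过根的合法点对,分为两种情况:一种是两点位于根的不同分支(不同支),这种好处理一些:求出块内各点到根的距离,排序后用双指针统计距离之和<=k的点对即可;另一种是两点位于根的同一分支(同支),它们之间的路径并不经过根,却在上一步被多算了,所以要对每个分支单独做一次同样的统计,再把这些重复的计数减去,只是实现起来麻烦些。处理完后删去根,对各分支递归求解。另外推荐一篇博客 详解,我也是看的他的,很厉害。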

#include <iostream>
#include <cstdio>
#include <cstring>
#include <stack>
#include <queue>
#include <map>
#include <set>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
const double eps = 1e-6;        // not referenced in this file
const double pi = acos(-1.0);   // not referenced in this file
const int INF = 0x3f3f3f3f;     // "infinity" sentinel, used as the initial minimum in GetRoot
const int MOD = 1000000007;     // not referenced in this file
#define ll long long
#define CL(a,b) memset(a,b,sizeof(a))
const int MAXN = 20010;         // array capacity; must exceed the 2*(n-1) directed edges for n <= 10000

// A single record type used for two purposes:
//  * as an adjacency-list edge in tree[] / head[] (fields v, len, next);
//  * as per-vertex scratch storage in dp[] (fields sum, bat).
struct node
{
    int v, len;     // edge: target vertex and edge length
    int sum, bat;   // dp: size of the vertex's subtree (itself included) /
                    //     size of its largest child subtree
    node *next;     // next edge leaving the same vertex
};
node tree[MAXN];    // edge pool, filled by add()
node *head[MAXN];   // head[u]: first edge leaving u, NULL if none
node dp[MAXN];      // dp[u].sum / dp[u].bat, filled by dfs()

int n, k, ans, pre, root, tot;
int vis[MAXN];      // vis[u] != 0 once u has been taken as a centroid (removed)
int dist[MAXN];     // distances gathered by GetDist(), tot entries in use
int size[MAXN];     // size[i]: largest-child-subtree size of the i-th vertex in dfs post-order
int sign[MAXN];     // sign[i]: the i-th vertex visited by dfs, in post-order

void init()
{
    ans = pre = 0;
    for(int i=0; i<MAXN; i++)
        vis[i]=0, head[i]=NULL;
}

void add(int a, int b, int c)
{
    tree[pre].v = b; tree[pre].len = c;
    tree[pre].next = head[a]; head[a] = &tree[pre++];
}

// Walks the component containing `son` (skipping `father` and every vertex
// already removed, i.e. vis != 0), rooted at `son`, and fills in:
//   dp[v].sum : number of vertices in v's subtree, v itself included
//   dp[v].bat : size of v's largest child subtree
//   sign[i]   : the i-th vertex finished, in post-order
//   size[i]   : dp[sign[i]].bat
// The caller must reset `tot` first; afterwards `tot` equals the number of
// vertices visited.
void dfs(int son, int father)
{
    dp[son].sum = dp[son].bat = 0;
    for (node *p = head[son]; p != NULL; p = p->next)
    {
        if (p->v != father && vis[p->v] == 0)
        {
            dfs(p->v, son);
            dp[son].sum += dp[p->v].sum;                    // accumulate subtree size
            dp[son].bat = max(dp[son].bat, dp[p->v].sum);   // largest child subtree
        }
    }
    dp[son].sum++;   // count `son` itself
    sign[tot] = son;
    // Qualified as `::size`: with `using namespace std;` in C++17 an
    // unqualified `size` is ambiguous with the function template std::size.
    ::size[tot++] = dp[son].bat;
}

int GetRoot(int son)
{
    tot = 0; dfs(son, 0);
    int maxx=INF, maxi, cnt=dp[son].sum;
    for(int i=0; i<tot; i++)
    {
        size[i] = max(size[i], cnt-size[i]);
        if(size[i] < maxx)
        {
            maxx = size[i];
            maxi = sign[i];
        }
    }
    return maxi;
}

void GetDist(int son, int father, int dis)//保存每个结点到根结点的距离
{
    node *p = head[son];
    dist[tot++] = dis;
    while(p != NULL)
    {
        if(p->v!=father&&vis[p->v]==0&&dis+p->len<=k)
            GetDist(p->v, son, dis+p->len);
        p = p->next;
    }
}

// Different-branch case: sorts dist[0..tot) and adds to `ans` the number of
// index pairs i < j with dist[i] + dist[j] <= k (two-pointer sweep).
// `son` is not used.
void count1(int son)
{
    sort(dist, dist + tot);
    int lo = 0;
    int hi = tot - 1;
    while (lo < hi)
    {
        if (dist[lo] + dist[hi] > k)
        {
            --hi;   // too large even with the smallest remaining partner
            continue;
        }
        ans += hi - lo;   // dist[lo] pairs with every index in (lo, hi]
        ++lo;
    }
}

void count2(int son)//同支
{
    vis[son] = 1;
    node *p = head[son];
    while(p != NULL)
    {
        if(vis[p->v] == 0)
        {
            tot = 0; GetDist(p->v, son, p->len);
            sort(dist, dist+tot);
            int left=0, right=tot-1;
            while(left < right)
            {
                if(dist[left]+dist[right] <= k)
                    ans -= (right - left), left++;
                else right--;
            }
        }
        p = p->next;
    }
}

// Centroid decomposition step for the component containing `son`:
// picks its centroid, adds the valid pairs whose path passes through it
// (count1 minus the same-branch pairs removed by count2, which also marks
// the centroid as removed), then recurses into each remaining neighbour's
// component. `father` is a vertex to skip among the centroid's neighbours.
// Returns 0; the result is accumulated in the global `ans`.
int solve(int son, int father)
{
    // Keep the centroid in a local: the recursive calls below overwrite the
    // global `root`, which the old code then passed on as `father`.
    const int center = GetRoot(son);
    root = center;
    tot = 0;
    GetDist(center, 0, 0);
    count1(center);
    count2(center);
    for (node *p = head[center]; p != NULL; p = p->next)
    {
        if (p->v != father && vis[p->v] == 0)
            solve(p->v, center);
    }
    // The original fell off the end of this int function (undefined
    // behaviour); callers ignore the value, so return a defined one.
    return 0;
}

// Reads test cases until the "0 0" terminator (or end of input), builds the
// undirected tree for each, and prints the number of pairs at distance <= k.
int main()
{
    int a, b, c;
    // Require both numbers to be read: on EOF without a terminator the old
    // `scanf(...), n+k` kept stale n, k and looped forever.
    while (scanf("%d%d", &n, &k) == 2 && (n != 0 || k != 0))
    {
        init();
        for (int i = 1; i < n; i++)
        {
            if (scanf("%d%d%d", &a, &b, &c) != 3)
                return 1;   // truncated/malformed edge list
            add(a, b, c);   // store both directions of the tree edge
            add(b, a, c);
        }
        solve(1, 0);
        printf("%d\n", ans);
    }
    return 0;
}


你可能感兴趣的:(树形DP)