hdu3534,个人认为很经典的树形dp

题目大意为,求一个树的直径(最长路),以及直径的数量

朴素的dp只能找出某点开始的最长路径,但这个最长路径却不一定是树的直径,本弱先开始就想简单了,一直wa

直到我看了某位大牛的题解。。。

按照那位大牛的思路,我们来考虑直径的构成:

情况1:由某叶子节点出发产生的最长路径直接构成

情况2:由某有多个儿子的节点出发产生的两条长路径组成,这其中,又可以分为两条长路径长度相等与否两种情况

所以 在dp的时候,我们需要记录每个节点出发产生的最长路径和次长路径,以及他们的数量,数量的统计也是非常麻烦

详细请见代码:

#include<stdio.h>

#include<iostream>

#include<stdlib.h>

#include<math.h>

#include<ctype.h>

#include<algorithm>

#include<string>

#include<string.h>

#include<queue>

#define mod 998244353

#define MAX 100000000

using namespace std;

int t,n,m,p,k,tt,f;

int x;



int head[10010];

typedef struct Node

{

    int en;

    int value;

    int next;

}node;

node edge[20010];

typedef struct DPnode

{

    int dp1,dp2,len,nn;

    int n1,n2;

}DP;

DP dp[10010];

void ini()

{

    int x,y,z;

    for(int i=1;i<=n-1;i++)

    {

        scanf("%d%d%d",&x,&y,&k);

        edge[2*i-1].en=y;

        edge[2*i-1].next=head[x];

        edge[2*i-1].value=k;

        head[x]=2*i-1;

        edge[2*i].en=x;

        edge[2*i].next=head[y];

        edge[2*i].value=k;

        head[y]=2*i;

    }

}

void dfs(int s,int p)

{

    dp[s].dp1=dp[s].dp2=dp[s].len=dp[s].n1=dp[s].n2=dp[s].nn=0;

    int leaf=1;

    for(int i=head[s];i;i=edge[i].next)

    {

        int q=edge[i].en;

        if(q==p)

            continue;

        leaf=0;

        dfs(q,s);

        int tmp=dp[q].dp1+edge[i].value;

        if(tmp>dp[s].dp1)

        {

            dp[s].dp2=dp[s].dp1;

            dp[s].n2=dp[s].n1;

            dp[s].dp1=tmp;

            dp[s].n1=dp[q].n1;

        }

        else if(tmp==dp[s].dp1)

        {

            dp[s].n1+=dp[q].n1;

        }

        else if(tmp>dp[s].dp2)

        {

            dp[s].dp2=tmp;

            dp[s].n2=dp[q].n1;

        }

        else if(tmp==dp[s].dp2)

        {

            dp[s].n2+=dp[q].n1;

        }

    }

    if(leaf)

    {

        dp[s].n1=1;dp[s].nn=1;

        dp[s].len=0;

        dp[s].dp1=0;

        return;

    }

    int c1=0,c2=0;

    for(int i=head[s];i;i=edge[i].next)

    {

        int q=edge[i].en;

        if(q==p)

            continue;

        int tmp=dp[q].dp1+edge[i].value;

        if(tmp==dp[s].dp1)

            c1++;

        else if(tmp==dp[s].dp2&&dp[s].dp2)

            c2++;

    }

    if(c1>1)

    {

        dp[s].len=dp[s].dp1*2;

        int sum=0;

        for(int i=head[s];i;i=edge[i].next)

        {

            int q=edge[i].en;

            if(q==p)

                continue;

            if(dp[q].dp1+edge[i].value==dp[s].dp1)

            {

                dp[s].nn+=sum*dp[q].n1;

                sum+=dp[q].n1;

            }

        }

    }

    else if(c2>0)

    {

        dp[s].len=dp[s].dp1+dp[s].dp2;

        for(int i=head[s];i;i=edge[i].next)

        {

            int q=edge[i].en;

            if(q==p)

                continue;

            if(dp[q].dp1+edge[i].value==dp[s].dp2)

            {

                dp[s].nn+=dp[s].n1*dp[q].n1;

            }

        }

    }

    else

    {

        dp[s].len=dp[s].dp1;

        dp[s].nn=dp[s].n1;

    }

    return ;

}

void solve()

{

    int ans=0;

    int num=0;

    for(int i=1;i<=n;i++)

    {

        if(dp[i].len>ans)

        {

            ans=dp[i].len;

            num=dp[i].nn;

        }

        else if(dp[i].len==ans)

        {

            num+=dp[i].nn;

        }

    }

    printf("%d %d\n",ans,num);

}

int main()

{

    while(scanf("%d",&n)!=EOF)

    {

        memset(head,0,sizeof(head));

        ini();

        dfs(1,0);

        solve();

    }

    return 0;

}

 

你可能感兴趣的:(HDU)