poj 2378 Tree Cutting (树形dp)

<a target=_blank href="http://http://poj.org/problem?id=2378"><span style="font-size:24px;">看题目请戳我</span></a>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
#define maxx   10050
int num[maxx];//记录孩子的个数
int dp[maxx];//记录把这个节点删除之后剩下的最大的连通度
int sum;
vector <int>root[maxx];
bool vis[maxx];
int dfs(int x)//返回这个数的孩子有几个
{
    int n=root[x].size();
    vis[x]=1;
    num[x]=1;
    for(int i=0;i<n;i++)
    {
        if(!vis[root[x][i]])
            num[x]+=dfs(root[x][i]);
    }
    return num[x];
}
void  woo(int m)
{
    int k;
    int n=root[m].size();
    vis[m]=1;
    for(int i=0;i<n;i++)
    {
        k=root[m][i];
        if(vis[k])//如果他已经被遍历过了,判断他的孩子多还是先辈多
        {
            dp[m]=max(dp[m],sum-num[m]);
        }
        else//把它孩子的个数赋予它
        {
            dp[m]=max(dp[m],num[k]);
            woo(k);
        }
    }
}
int main()
{
    while(scanf("%d",&sum)!=EOF)
    {
         for(int i=1;i<=sum;i++)
            root[i].clear();
        int a,b;
        int c=sum;
        memset(dp,0,sizeof(dp));
        for(int i=1;i<sum;i++)
        {
            scanf("%d%d",&a,&b);
            root[a].push_back(b);
             root[b].push_back(a);
        }
        memset(vis,0,sizeof(vis));
        dfs(1);
        memset(vis,0,sizeof(vis));
        woo(1);
        int flag=0;
        for(int i=1;i<=c;i++)
        {
            if(dp[i]<=c/2)
            {
                flag=1;
                printf("%d\n",i);
            }
        }
        if(!flag)
            puts("NONE");
    }

}

邻接表代码    邻接表超快.

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxx 200010
struct node
{
   int xx,next;
}g[maxx];
int dp[maxx];
int num[maxx];
int vis[maxx];
int head[maxx];
int n,cnt;
void add(int u,int v)
{
    g[cnt].next=head[u];
    g[cnt].xx=v;
    head[u]=cnt++;
}
int dfs(int x)
{
    vis[x]=1;
    num[x]=1;
    for(int i=head[x];i!=-1;i=g[i].next)
    {
        int k=g[i].xx;
        if(!vis[k])
        {
            num[x]+=dfs(k);
            dp[x]=max(num[k],dp[x]);
        }
    }
    dp[x]=max(dp[x],n-num[x]);
    return num[x];
}
int main()
{
    int a,b;
    scanf("%d",&n); cnt=0;
    memset(dp,0,sizeof(dp));
    memset(head,-1,sizeof(head));
    memset(num,0,sizeof(num));
    memset(vis,0,sizeof(vis));
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&a,&b);
        add(a,b);
        add(b,a);
    }
    dfs(1);
    for(int i=1;i<=n;i++)
    {
        if(dp[i]<=n/2)

            printf("%d\n",i);
    }
}


你可能感兴趣的:(dp)