poj2486 树形dp

先声明,我是用纯C写的
主要是理解这个动态转移方程,利用三维数组记录每个时刻的状态,dp[i][j][0]代表i结点有j步可以走,并返回i结点时的最大值
dp[i][j][1]代表i结点有j步可以走,不返回i结点时的最大值

dp[root][j][0] = MAX (dp[root][j][0] , dp[root][j-k][0] + dp[son][k-2][0]);//从s出发,要回到s,需要多走两步s-t,t-s,分配给t子树k步,其他子树j-k步,都返回

dp[root][j]][1] = MAX( dp[root][j][1] , dp[root][j-k][0] + dp[son][k-1][1]) ;//先遍历s的其他子树,回到s,遍历t子树,在当前子树t不返回,多走一步

dp[root][j][1] = MAX (dp[root][j][1] , dp[root][j-k][1] + dp[son][k-2][0]);//不回到s(去s的其他子树),在t子树返回,同样有多出两步

你可以这样理解,假设3个结点,一个根结点,左右各一个结点。那么总共只有3种走法:

1.左子树t走完,并返回,回到根结点,多了2步,再走右子树

2.先遍历右子树,回到根节点,走左子树t,不再返回,多一步

3.遍历左子树,返回多2步,走右子树

代码:

#include<stdio.h>

struct node
{
    int v;
    int next;
}tree[505];

int dp[205][505][2],head[505],val[505];
int n,k,len;

void Init()      //初始化函数
{
    int i,j,q;
    for(i=0;i<205;i++)
        for(j=0;j<505;j++)
            for(q=0;q<2;q++)
                dp[i][j][q]=0;
    for(i=0;i<505;i++)
    {
        head[i]=-1;
        val[i]=0;
    }
}

int max(int a,int b)   //返回最大值函数
{

    if(a>b)
        return a;
    else
        return b;
}

void add(int u,int v)
{
    tree[len].v=v;
    tree[len].next=head[u];
    head[u]=len++;
}

void dfs(int root,int mark)
{
    int i,j,t,son;
    for(i=head[root];i!=-1;i=tree[i].next)
    {   
        son=tree[i].v;
        if(son==mark) //
            continue;
        dfs(son,root);
        for(j=k;j>=1;j--)
            for(t=1;t<=j;t++)  //0返回 
            {/* 从根节点u走到v要耗费一步,若要从v回到u又需要再耗费一步 */  
                if(t>=2)
                {
                  dp[root][j][0]=max(dp[root][j][0],dp[root][j-t][0]+dp[son][t-2][0]);
                  dp[root][j][1]=max(dp[root][j][1],dp[root][j-t][1]+dp[son][t-2][0]);
                }
                dp[root][j][1]=max(dp[root][j][1],dp[root][j-t][0]+dp[son][t-1][1]);
            }

    }
}





int main()
{
    int a,b;
    int i,j;
    while(scanf("%d%d",&n,&k))
    {
        Init();
        for(i=1;i<=n;i++)
        {
            scanf("%d",&val[i]);
            for(j=0;j<=k;j++)
                dp[i][j][0]=dp[i][j][1]=val[i];
        }
        len=0;
        for(i=1;i<n;i++)
        {
            scanf("%d%d",&a,&b);
            add(a,b);
            add(b,a);
        }
        dfs(1,-1);
        printf("%d\n",max(dp[1][k][0],dp[1][k][1]));
    }

    return 0;
}

可惜的是,我去poj测定时,结果超时了,想半天没想出来,有思路的大神求告诉!

你可能感兴趣的:(dp,C语言,poj)