poj 2486 树形DP 子树合并


大致思路是从根节点DFS下去后,处理完子树信息后,合并然后得出根节点信息,
但是这一题子节点信息的合并确实比较棘手.

如果我们尝试把从当前根节点走K步得到最大苹果数,划分为如下两种:
     
  go[t][i]代表节点t的所有子树上走i步不返回,取得的最大苹果数
  bk[t][i]代表节点t的所有子树上走i步并返回,取得的最大苹果数
 
数组二维go,bk go[t][i]代表节点t的所有子树上至多走i步不返回,取得的最大苹果数 bk[t][i]代表节点t的所有子树上至多走i步并返回,取得的最大苹果数 求节点为x,实行不断合并子树求最优值 当前合并到了q棵子树: go[x][i]就是这q棵子树上至多走i步不返回的最优值 bk[x][i]就是这q棵子树上至多走i步并返回的最优值 合并第q+1棵子树(不妨设第q+1棵子树的根为y)的时候,有 go[x][i] = max( bk[x][j]+go[y][i-j], bk[y][j],go[x][i-j] ), j=0.....i bk[x][i] = max( bk[x][j]+bk[y][i-j] ) j=0,.....i;

关于边界的初始化问题, 对于当前以x为根的树, 因为递归处理好了其子节点J为根的子树,
但是处理子树J的时候,把J节点看作0来处理,回朔到X节点,再处理J节点时再去计算从X走到J的苹果值.
这样,
  若计算对于go[j][k] 则需要添加一步由x到j的路径,所以go[j][k] = go[j][k-1]+Apple[j]
  若计算对于bk[j][k] 则需要添加来回一趟,共两步从x到j,再回到x的路径,所以bk[j][k] = bk[j][k-2]+Apple[j]
还需要注意, bk[j][1] = bk[j][0] = 0 , go[j][0] = 0, 前者无法走回,后者还没开始走,当然为0
// Code by yefeng1627

// Time 2013-1-17

#include<stdio.h>

#include<stdlib.h>

#include<string.h>

#include<vector>

using namespace std;



const int N = 110;

int go[N][N<<1],bk[N][N<<1],tmp1[N<<1],tmp2[N<<1];

int n, k, Apple[N];

vector<int> Q[N];



int max(int a,int b) {return a>b?a:b;}



void DP(int x, int y )

{

    for(int i = 0; i <= k; i++) tmp1[i]=tmp2[i]=0;    

    for(int i = k; i >= 0; i--)

        for(int j = 0; j <= i; j++)

            tmp1[i] = max(tmp1[i],max(bk[x][j]+go[y][i-j],go[x][j]+bk[y][i-j]) );

    for(int i = k; i >= 0; i--)

        for(int j = 0; j <= i; j++)

            tmp2[i] = max(tmp2[i],bk[x][j]+bk[y][i-j]);

    for(int i = 0; i <= k; i++) 

        go[x][i] = tmp1[i], bk[x][i] = tmp2[i];

}

void solve(int x, int fa)

{

    int y;

    for(int i = 0; i < (int)Q[x].size(); i++ )

        if( (y=Q[x][i]) != fa )     

        {

            solve( y, x );

            for(int L = k; L >= 2; L-- ) bk[y][L] = bk[y][L-2]+Apple[y];

            bk[y][1] = bk[y][0] = go[y][0] = 0;    

            for(int L = k; L >= 1; L-- ) go[y][L] = go[y][L-1]+Apple[y];

            DP( x, y );

        }

}

int main()

{

    while( scanf("%d%d",&n, &k) != EOF )

    {

        memset(go,0,sizeof(go));

        memset(bk,0,sizeof(bk));

        for(int i = 0; i < n; i++)

        {

            scanf("%d", &Apple[i] );    

            Q[i].clear();    

        }

        int a, b;

        for(int i = 0; i < n-1; i++)

        {    

            scanf("%d%d",&a,&b);

            a--; b--;

            Q[a].push_back(b);

            Q[b].push_back(a);

        }

        solve(0, -1);    

        printf("%d\n", go[0][k]+Apple[0] );        

    }

    return 0;

}

 

你可能感兴趣的:(poj)