我们通过一个题目引入,这也是树形dp的一道经典例题——没有上司的舞会。
题目描述
Ural 大学有 N N N 名职员,编号为 1 ∼ N 1∼N 1∼N。他们的关系就像一棵以校长为根的树,父节点就是子节点的直接上司。每个职员有一个快乐指数,用整数 H i H_i Hi 给出,其中 1 ≤ i ≤ N 1≤i≤N 1≤i≤N。现在要召开一场周年庆宴会,不过,没有职员愿意和直接上司一起参会。在满足这个条件的前提下,主办方希望邀请一部分职员参会,使得所有参会职员的快乐指数总和最大,求这个最大值。
输入格式
第一行一个整数 N N N。
接下来 N N N行,第 i i i 行表示 i i i 号职员的快乐指数 H i H_i Hi。
接下来 N − 1 N−1 N−1 行,每行输入一对整数 L , K L,K L,K表示 K K K 是$ L$ 的直接上司。(注意一下,后一个数是前一个数的父节点,不要搞反)。
输出格式
输出最大的快乐指数。
数据范围
1 ≤ N ≤ 6000 1≤N≤6000 1≤N≤6000
− 128 ≤ H i ≤ 127 −128≤H_i≤127 −128≤Hi≤127
题目分析
先按照动态规划的板子思考一下这道题目,
第一个阶段定义dp数组
(1)缩小规模。
搞清这道题的规模是什么,按照之前的思路,我要考虑n个员工每个员工是否参加舞会,规模应该是这n个员工,但是有一点不同,员工之间的关系是一个树状的,也就是我不能和之前一样从1遍历到n,而是遍历一棵树,树的遍历一般从根节点开始,向子节点遍历,那么这里的规模依然是员工,但是dp[i]表示的不是前i个员工,而是以i为根节点的树。
(2)考虑限制。
这道题目的限制就是当父节点参加了舞会时,儿子节点不能参加舞会,那么我在判断当前节点是否能参加舞会时,我需要知道它的父节点是否参加舞会,所以需要第二个维度,那么就是 d p [ i ] [ j ] dp[i][j] dp[i][j],其中j要么为0,要么为1。j为0表示当前节点i没有参加舞会,j为1表示当前节点i参加了舞会。
(3)定义dp数组。
d p [ i ] [ 0 ] dp[i][0] dp[i][0]表示以i为根节点且i没有参加舞会的子树获得的最大快乐指数。 d p [ i ] [ 1 ] dp[i][1] dp[i][1]表示以i为根节点且i参加舞会的子树获得的最大快乐指数。这里求什么dp数组就表示什么。
第二个阶段推导状态转移方程
d p [ i ] [ 0 ] + = m a x ( d p [ j ] [ 0 ] , d p [ j ] [ 1 ] ) dp[i][0]+=max(dp[j][0],dp[j][1]) dp[i][0]+=max(dp[j][0],dp[j][1]),其中j是i的儿子节点,因为i没有被选择,所以儿子节点既可以被选中,又可以不被选中。
d p [ i ] [ 1 ] + = d p [ j ] [ 0 ] dp[i][1]+=dp[j][0] dp[i][1]+=dp[j][0],其中j是i的儿子节点,因为i被选择,所以儿子节点只可以不被选中。
第三个阶段写代码
(1)dp数组的初始化。这里dp数组最初始的状态应该是dp[i]中i为叶子节点的情况。这种情况
(2)递推dp数组
这里的递推方式就是树形dp特有的了,从根节点开始往下递推,其实这里是自上而下的递推,规模一开始是最大的,然后为了解决最大规模的问题,需要先把相应的子规模的答案求出来,也就是层层递归,递归到叶子节点后再层层返回即可。
详细说一下这一部分的代码,那么dfs(u)返回的是以u为根节点的子树的快乐值。对于 d p [ u ] [ 1 ] dp[u][1] dp[u][1]表示选择了节点i,那么就要加上节点i的快乐值,所以有dp[u][1] += a[u];
。然后遍历节点i的儿子节点,对儿子节点进行dfs,求以儿子节点为根的子树的快乐值,求出来后,也就是dfs结束后,用它来更新当前节点的快乐值,这里的更新就是刚刚我们推导的状态转移方程。代码如下,
private static void dfs(int u) {
// TODO Auto-generated method stub
dp[u][1] += a[u];
for (int i = 0; i < q[u].size(); i++) {
int to = q[u].get(i);
dfs(to);
dp[u][1] += dp[to][0];
dp[u][0] += Math.max(dp[to][0], dp[to][1]);
}
}
(3)答案的表示
m a x ( d p [ 根节点 ] [ 0 ] , d p [ 根节点 ] [ 1 ] ) max(dp[根节点][0],dp[根节点][1]) max(dp[根节点][0],dp[根节点][1])表示答案。
题目代码
关于代码还是有几点要说。
存图的方式,这里存图用的是比较简单的链表存图。关键代码是
static ArrayList<Integer>[] q;
q[i] = new ArrayList<Integer>();
for (int i = 1; i < a.length - 1; i++) {
sc.nextToken();
int u = (int)sc.nval;
sc.nextToken();
int v = (int)sc.nval;
visit[u] = true;
q[v].add(u);
}
找出哪个节点是根节点,也就是没有父节点的节点。
for (int i = 1; i < a.length - 1; i++) {
sc.nextToken();
int u = (int)sc.nval;
sc.nextToken();
int v = (int)sc.nval;
visit[u] = true;//标记u已经有父节点了
q[v].add(u);
}
int root = -1;
for (int i = 1; i < a.length; i++) {
if(!visit[i]) {
root = i;//找到根节点
}
}
题目代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StreamTokenizer;
import java.util.ArrayList;
public class Main {
static boolean[] visit;
static long[][] dp;
static int[] a;
static ArrayList<Integer>[] q;
public static void main(String[] args) throws IOException{
StreamTokenizer sc=new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
sc.nextToken();
int n = (int)sc.nval;
a = new int[n + 1];
q = new ArrayList[n+1];
dp = new long[n + 1][2];
visit = new boolean[n + 1];
for (int i = 1; i < a.length; i++) {
sc.nextToken();
a[i] = (int)sc.nval;
q[i] = new ArrayList<Integer>();//初始化链表
}
for (int i = 1; i < a.length - 1; i++) {
sc.nextToken();
int u = (int)sc.nval;
sc.nextToken();
int v = (int)sc.nval;
visit[u] = true;//标记u已经有父节点了
q[v].add(u);//存图
}
int root = -1;
for (int i = 1; i < a.length; i++) {
if(!visit[i]) {
root = i;//找到根节点
}
}
dfs(root);//从根节点遍历
System.out.println( Math.max(dp[root][1], dp[root][0]));
}
private static void dfs(int u) {
// TODO Auto-generated method stub
dp[u][1] += a[u];
for (int i = 0; i < q[u].size(); i++) {
int to = q[u].get(i);
dfs(to);
dp[u][1] += dp[to][0];
dp[u][0] += Math.max(dp[to][0], dp[to][1]);
}
}
}