[SHOI2015]聚变反应炉

好久没有搞过树形dp的题了,它对新人很不友好,我就来补一发超详细的题解吧。

一、题目

点此看题
题意
给定一棵树,其中每个号节点如果被点亮,就会对周围相连的节点发出ci格能量,点亮第i个节点需要的能量点数为di。问能点亮整棵树的最小能量花费。
数据范围
对于 50 % 50\% 50%的数据, max ⁡ c i < = 1 , n < = 100000 \max{ci}<=1,n<=100000 maxci<=1,n<=100000
对于另外 50 % 50\% 50%的数据, max ⁡ c i < = 5 , n < = 2000 \max{ci}<=5,n<=2000 maxci<=5,n<=2000

二、解法

第一眼看到这个数据范围,觉得特别神奇,我们考虑数据分治。
0x01 贪心
先来考虑第一个范围,因为此时的 c c c只包含0或1,我们考虑 c = 1 c=1 c=1的点的点亮顺序。无论怎么点亮,减少的能量花费总量都是一定的(可以自举例子加深理解),所以最优解和顺序无关,我们只需优先点亮 c = 1 c=1 c=1的点即可。
0x02 树形dp
第二个范围就没有那么容易了,受贪心的启发,我们考虑点亮顺序。利用树形dp,我们对局部的点亮顺序进行考虑我们定义 d p [ i ] [ 0 / 1 ] dp[i][0/1] dp[i][0/1]为对于第 i i i个节点先染 它 / / /它的父亲,并把 i i i i i i的子树染完的最小花费。
显然答案为 d p [ 1 ] [ 0 ] dp[1][0] dp[1][0] 1 1 1的父亲不存在,用它统计最优答案)。
现在我们考虑对 d p dp dp数组的转移。
由于 c c c并不是很大,我们可以把它作为一个维度,为了辅助转移,我们再定义 t m p [ s ] tmp[s] tmp[s]为染完 u u u的子树,并且接受了儿子们值为 s s s的能量传递的最小花费(这里并不用去管 u u u是否被染色),我们先遍历完子树,再用滚动数组的形式合并子树的信息,详细操作见下。

memset(tmp,0x3f,sizeof tmp);
tmp[0][0]=0;//初始化
for(int i=f[u];i;i=e[i].next)//枚举每个儿子
{	
	int v=e[i].v;
	if(v==fa) continue;
	cur^=1;//滚动数组,用tmp[cur^1]更新tmp[cur]
	memset(tmp[cur],0x3f,sizeof tmp[cur]);
	for(int j=0;j<=sum-c[v];j++)//此时我们可以使用dp[v]
	{
            tmp[cur][j+c[v]]=min(tmp[cur][j+c[v]],tmp[cur^1][j]+dp[v][0]);//先点亮v
            tmp[cur][j]=min(tmp[cur][j],tmp[cur^1][j]+dp[v][1]);//先点亮u
     	}
}

拿到了 t m p tmp tmp之后,更新就要简单一些了,我们可写出如下转移:
d p [ u ] [ 0 ] = m i n ( d p [ u ] [ 0 ] , m a x ( t m p [ i ] , t m p [ i ] − i + d [ u ] ) ) ; dp[u][0]=min(dp[u][0],max(tmp[i],tmp[i]-i+d[u])); dp[u][0]=min(dp[u][0],max(tmp[i],tmp[i]i+d[u]));
d p [ u ] [ 1 ] = m i n ( d p [ u ] [ 1 ] , m a x ( t m p [ i ] , t m p [ i ] − i + d [ u ] − c [ f a ] ) ) ; dp[u][1]=min(dp[u][1],max(tmp[i],tmp[i]-i+d[u]-c[fa])); dp[u][1]=min(dp[u][1],max(tmp[i],tmp[i]i+d[u]c[fa]));
其中 i i i是从儿子得到的能量,我们防止能量溢出(即有多余的能量),故取max。

0x3f 代码
作者口胡可能难以理解,更多请看完整版代码,也可以单独向作者提问qwq。

#include 
#include 
const int MAXN = 100005;
const int MAXM = 2005;
int read()
{
    int x=0,flag=1;
    char c;
    while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
    while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*flag;
}
int n,tot,ans,mc,d[MAXN],c[MAXN],f[MAXN];
struct edge
{
    int v,next;
} e[MAXN*2];
int max(int x,int y)
{
    if(x<y) return y;
    return x;
}
int min(int x,int y)
{
    if(x<y) return x;
    return y;
}
int dp[MAXM][2],tmp[2][5*MAXM];
void dfs(int u,int fa)
{
    int sum=0,cur=0;
    for(int i=f[u]; i; i=e[i].next)
    {
        int v=e[i].v;
        if(v==fa) continue;
        dfs(v,u);
        sum+=c[v];
    }
    memset(tmp,0x3f,sizeof tmp);
    tmp[0][0]=0;
    for(int i=f[u]; i; i=e[i].next)
    {
        int v=e[i].v;
        if(v==fa) continue;
        cur^=1;
        memset(tmp[cur],0x3f,sizeof tmp[cur]);
        for(int j=0; j<=sum-c[v]; j++)
        {
            tmp[cur][j+c[v]]=min(tmp[cur][j+c[v]],tmp[cur^1][j]+dp[v][0]);
            tmp[cur][j]=min(tmp[cur][j],tmp[cur^1][j]+dp[v][1]);
        }
    }
    dp[u][0]=dp[u][1]=0x3f3f3f3f;
    for(int i=0; i<=sum; i++)
    {
        dp[u][0]=min(dp[u][0],max(tmp[cur][i],tmp[cur][i]-i+d[u]));
        dp[u][1]=min(dp[u][1],max(tmp[cur][i],tmp[cur][i]-i+d[u]-c[fa]));
    }
    return ;
}
int main()
{
    n=read();
    for(int i=1; i<=n; i++)
        d[i]=read();
    for(int i=1; i<=n; i++)
        c[i]=read(),mc=max(mc,c[i]);
    for(int i=1; i<n; i++)
    {
        int u=read(),v=read();
        e[++tot]=edge{u,f[v]},f[v]=tot;
        e[++tot]=edge{v,f[u]},f[u]=tot;
    }
    if(mc<=1)
    {
        for(int i=1; i<=n; i++)
            if(c[i]==1)
            {
                ans+=d[i];
                d[i]=0;
                for(int j=f[i]; j; j=e[j].next)
                    d[e[j].v]--;
            }
        for(int i=1; i<=n; i++)
            if(d[i]>0)
                ans+=d[i];
        printf("%d\n",ans);
        return 0;
    }
    dfs(1,0);
    printf("%d\n",dp[1][0]);
    return 0;
}

你可能感兴趣的:(树形dp)