以为很水的一道题,花了大半天的时间才搞定,比赛的时候卡在这题上了,伤不起啊。。。
题意:给一棵树,每个结点中有礼物,每个礼物有一个权值,某些结点中会有陷阱,你可以从任何一点出发,每个结点最多只能经过一次,最多掉进陷阱C次,求出可获得的礼物的最大值。
思路:典型的树形DP ,状态可用dp[x][y][z]来表示,x代表以x为根结点的子树,y代表恰好经过了几个陷阱,z代表方向(0,1),表示从此子树进来或是出去。
dp[x][y][0]的含义是从子树x中出来,恰好经过y个陷阱所能获得的最大礼物值,dp[x][y][1]的含义是进入子树x,恰好经过y个陷阱所能获得的最大礼物值。
状态转移方程:
如果结点x有陷阱:dp[x][y][z]=max(dp[u][y-1][z])
如果结点x无陷阱:dp[x][y][z]=max(dp[u][y][z]),u是x的子结点
先把无根树转化为有根树,用一个全局变量记录答案,每求出一棵子树的dp值,就把答案更新一次:一种找出该子树中两棵不同的子子树,一颗出树,一颗入树,在两树通过父结点x连到一起,在该路径上的陷阱数不超过C的情况下,通过枚举两棵子子树上的陷阱数来更新答案,另一种是找出一棵子子树上的单条路径,在陷阱数不超过C的情况下,更新答案。
一定要注意该子树根结点有陷阱和无陷阱时要分开处理。
整体采用记忆化搜索实现。
细节太多,不再一一叙述,详情见代码:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define LL long long using namespace std; vector<int>v[50005]; int N,C,b[50005],r[50005]; LL a[50005],res,dp[50005][4][2]; bool vis[50005][4][2]; void root(int x) { for (int i=0; i<v[x].size(); ++i) if (v[x][i]!=r[x]) { r[v[x][i]]=x; root(v[x][i]); } } LL dfs(int x,int y,int z); void update(int x) { vector<LL>ls[4],rs[4],e[4]; if (v[x].size()==1 && x) { if (!b[x]) res=max(res,a[x]); else if (C) res=max(res,a[x]); } else if ((!x && v[x].size()==1) || (x && v[x].size()==2)) return; else if (!b[x]) { for (int j=0; j<4; ++j) for (int i=0; i<v[x].size(); ++i) if (v[x][i]!=r[x]) { if (dfs(v[x][i],j,0)) { ls[j].push_back(dfs(v[x][i],j,0)+a[x]); rs[j].push_back(dfs(v[x][i],j,0)+a[x]); } else { ls[j].push_back(dfs(v[x][i],j,0)); rs[j].push_back(dfs(v[x][i],j,0)); } if (dfs(v[x][i],j,1)) e[j].push_back(dfs(v[x][i],j,1)+a[x]); else e[j].push_back(dfs(v[x][i],j,1)); } for (int j=0; j<4; ++j) { LL now=ls[j][0]; for (int i=1; i<ls[j].size(); ++i) if (ls[j][i]) { if (ls[j][i]>now) now=ls[j][i]; else ls[j][i]=now; } now=rs[j][rs[j].size()-1]; for (int i=ls[j].size()-2; i>=0; --i) if (rs[j][i]) { if (rs[j][i]>now) now=rs[j][i]; else rs[j][i]=now; } } for (int i=0; i<=C; ++i) for (int k=0; k<e[i].size(); ++k) for (int j=0; j<=C-i; ++j) if (e[i][k] && ((k && ls[j][k-1]) || (k<rs[j].size()-1 && rs[j][k+1]))) { if (C && j==C) continue; LL u=0; if (k) u=max(u,ls[j][k-1]); if (k<rs[j].size()-1) u=max(u,rs[j][k+1]); res=max(res,u+e[i][k]-a[x]); } } else { for (int j=0; j<4; ++j) for (int i=0; i<v[x].size(); ++i) if (v[x][i]!=r[x]) { if (dfs(v[x][i],j,0)) { ls[j].push_back(dfs(v[x][i],j,0)+a[x]); rs[j].push_back(dfs(v[x][i],j,0)+a[x]); } else { ls[j].push_back(dfs(v[x][i],j,0)); rs[j].push_back(dfs(v[x][i],j,0)); } if (dfs(v[x][i],j,1)) e[j].push_back(dfs(v[x][i],j,1)+a[x]); else e[j].push_back(dfs(v[x][i],j,1)); } for (int j=0; j<4; ++j) { LL now=ls[j][0]; for (int i=1; i<ls[j].size(); ++i) if (ls[j][i]) { if (ls[j][i]>now) now=ls[j][i]; else ls[j][i]=now; } now=rs[j][rs[j].size()-1]; for (int i=ls[j].size()-2; i>=0; --i) if (rs[j][i]) { if (rs[j][i]>now) now=rs[j][i]; else rs[j][i]=now; } } for (int i=0; i<C; ++i) for (int k=0; k<e[i].size(); ++k) for (int j=0; j<=C-i-1; ++j) if (e[i][k] && ((k && ls[j][k-1]) || (k<rs[j].size()-1 && rs[j][k+1]))) { if (j==C-1) continue; LL u=0; if (k) u=max(u,ls[j][k-1]); if (k<rs[j].size()-1) u=max(u,rs[j][k+1]); res=max(res,u+e[i][k]-a[x]); } } } int main() { int T; // freopen("1006.in","r",stdin); // freopen("out.txt","w",stdout); scanf("%d",&T); while (T--) { int x,y; scanf("%d%d",&N,&C); for (int i=0; i<N; ++i) scanf("%I64d%d",&a[i],&b[i]); for (int i=0; i<N; ++i) v[i].clear(); for (int i=1; i<N; ++i) { scanf("%d%d",&x,&y); v[x].push_back(y); v[y].push_back(x); } r[0]=-1; root(0); res=0; memset(vis,0,sizeof(vis)); dfs(0,0,0); /*for (int i=0; i<N; ++i) for (int j=0; j<4; ++j) for (int k=0; k<2; ++k) printf("dp[%d][%d][%d]=%lld\n",i,j,k,dp[i][j][k]);*/ printf("%I64d\n",res); } return 0; } LL dfs(int x,int y,int z) { if (vis[x][y][z]) return dp[x][y][z]; for (int i=0; i<4; ++i) for (int j=0; j<2; ++j) { vis[x][i][j]=true; dp[x][i][j]=0; } if (v[x].size()==1 && x) { if (!b[x]) dp[x][0][0]=dp[x][0][1]=a[x]; else dp[x][1][0]=dp[x][1][1]=a[x]; } else if (!b[x]) { for (int i=0; i<v[x].size(); ++i) if (v[x][i]!=r[x]) for (int j=0; j<4; ++j) { dp[x][j][0]=max(dp[x][j][0],dfs(v[x][i],j,0)); dp[x][j][1]=max(dp[x][j][1],dfs(v[x][i],j,1)); } for (int i=0; i<4; ++i) for (int j=0; j<2; ++j) if (dp[x][i][j]) dp[x][i][j]+=a[x]; } else { for (int i=0; i<v[x].size(); ++i) if (v[x][i]!=r[x]) for (int j=1; j<4; ++j) { dp[x][j][0]=max(dp[x][j][0],dfs(v[x][i],j-1,0)); dp[x][j][1]=max(dp[x][j][1],dfs(v[x][i],j-1,1)); } for (int i=1; i<4; ++i) for (int j=0; j<2; ++j) if (dp[x][i][j]) dp[x][i][j]+=a[x]; dp[x][1][1]=a[x]; if (!dp[x][1][0]) dp[x][1][0]=a[x]; } update(x); if (!b[x]) { for (int i=0; i<C; ++i) res=max(res,dp[x][i][0]); for (int i=0;i<=C;++i) res=max(res,dp[x][i][1]); } else { for (int i=1;i<=C;++i) res=max(res,dp[x][i][0]); for (int i=1;i<=C;++i) res=max(res,dp[x][i][1]); } return dp[x][y][z]; }