传送门:【HDU】4616 Game
题目分析:首先,看到这道题,比较容易想到这需要树上的算法,再看看c的范围那么小,那么我们的思路便可以往树型DP上靠拢。
一开始,设dp[u][i]表示从u点出发恰好经过i个trap的最大价值,然后用树型DP维护dp[u][i],同时求得最大的ans。但是,dp[u][i]是在包括i个trap以后还可以延伸一段无trap的距离的,而题目要求最多走i个trap就直接终止,所以这样当我们便不能直接将u的两个子树上的最长路径dp[v1][x],dp[v2][y]直接相加了。
所以,我们需要给dp多一些限制,让dp变成三维来实现更多的状态。
设dp[u][i][j],当j=0时表示从u点出发恰好经过i个trap且最终只能停在第i个有trap的位置上的最大价值和(当i=0时可以停在自己的位置,但信息不会用来更新u的父节点的值);当j=1时便和二维时的意义相同,表示从u点出发恰好经过i个trap且还能走一段无trap的路径的最大价值和。
这样我们对于两个子树的路径就可以合成一条合法路径来更新ans了。
当相加的两条路径的trap加上节点u上的trap值小于c,显然dp[][][1]+dp[][][1]最优。
当相加的两条路径的trap加上节点u上的trap值等于c,则只有一端可以用dp[][][1],另一端必须dp[][][0]。
具体可以见代码~
代码如下:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std ; typedef long long LL ; #define rep( i , a , b ) for ( int i = ( a ) ; i < ( b ) ; ++ i ) #define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i ) #define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i ) #define clr( a , x ) memset ( a , x , sizeof a ) const int MAXN = 50005 ; const int MAXE = 100005 ; struct Edge { int v , n ; Edge () {} Edge ( int v , int n ) : v ( v ) , n ( n ) {} } ; Edge E[MAXE] ; int H[MAXN] , cntE ; int dp[MAXN][4][2] ; int val[MAXN] , trap[MAXN] ; int n , c ; int ans ; void clear () { cntE = 0 ; clr ( H , -1 ) ; clr ( dp , 0 ) ; } void addedge ( int u , int v ) { E[cntE] = Edge ( v , H[u] ) ; H[u] = cntE ++ ; } void dfs ( int u , int p ) { dp[u][trap[u]][1] = dp[u][trap[u]][0] = val[u] ; ans = max ( ans , val[u] ) ; for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( v == p ) continue ; dfs ( v , u ) ; For ( i , 0 , c ) { For ( j , 0 , c - i ) { if ( i + j == c ) { if ( j < c ) ans = max ( ans , dp[u][i][0] + dp[v][j][1] ) ; if ( i < c ) ans = max ( ans , dp[u][i][1] + dp[v][j][0] ) ; } else ans = max ( ans , dp[u][i][1] + dp[v][j][1]) ; } } For ( i , 0 , c - trap[u] ) { int j = i + trap[u] ; if ( i > 0 ) dp[u][j][0] = max ( dp[u][j][0] , dp[v][i][0] + val[u] ) ; if ( i < c ) dp[u][j][1] = max ( dp[u][j][1] , dp[v][i][1] + val[u] ) ; } } } void scanf ( int &x , char c = 0 ) { while ( ( c = getchar () ) < '0' || c > '9' ) ; x = c - '0' ; while ( ( c = getchar () ) >= '0' && c <= '9' ) x = x * 10 + c - '0' ; } void solve () { int u , v ; clear () ; scanf ( n ) ; scanf ( c ) ; rep ( i , 0 , n ) { scanf ( val[i] ) ; scanf ( trap[i] ) ; } rep ( i , 1 , n ) { scanf ( u ) ; scanf ( v ) ; addedge ( u , v ) ; addedge ( v , u ) ; } ans = 0 ; dfs ( 0 , 0 ) ; printf ( "%d\n" , ans ) ; } int main () { int T ; scanf ( T ) ; while ( T -- ) solve () ; return 0 ; }