HDU6820 Tree 树形dp

题目描述

给你一棵节点数为 N N N的无根树,每条边有权值,要求选出一棵联通的图,图里面至多允许有一个度数大于 K K K的点,而且要使得边权和尽量大
N ≤ 1 0 5 , 0 ≤ k < n N\leq 10^5,0 \leq kN105,0k<n

分析

我们可以想用树形dp,维护每走到一个点,当前点的度数就可以大于K的最大答案。
那么就是当前点可以连边到所有的点,但是与其连边的子树内都不能有度数大于K的点。
于是可以用 d p [ i ] dp[i] dp[i]表示 i i i向所有自己子树连边的最大价值
f [ i ] f[i] f[i]表示 i i i只能向子树内连 K − 1 K-1 K1条边(如果i是根节点,则可以连 K K K条边)的最大价值
所以就有:
d p [ x ] = ∑ f [ y ] + d i s ( x , y ) dp[x] = \sum\limits f[y] + dis(x,y) dp[x]=f[y]+dis(x,y)
f [ x ] = max ⁡ i = 1 k f [ y ] + d i s ( x , y ) f[x] = \max\limits_{i=1}^{k} f[y] + dis(x,y) f[x]=i=1maxkf[y]+dis(x,y)
然后向其他点转移的时候,同样道理统计子树外的情况就好了,这里涉及到一些细节,可以想一想。

代码

#include 
#define int long long
#define MP make_pair
#define PB push_back
#define CL clear
#define fi first
#define se second
using namespace std;
typedef pair<int,int> pii;
const int N = 2e5+10;
inline int rd() {
  int p=0; int f=1; char ch=getchar();
  while(ch<'0' || ch>'9'){if(ch=='-') f=-1; ch=getchar();}
  while(ch>='0' && ch<='9'){p=p*10+ch-'0'; ch=getchar();}
  return p*f;
}
struct node{int x,y,nex,d;}edge[N<<1]; int len,fir[N];
void ins(int x,int y,int d){len++; edge[len].x=x; edge[len].y=y; edge[len].d=d; edge[len].nex = fir[x]; fir[x] = len;}
int f[N],g[N],dp[N]; bool bo[N];
bool cmp(const pii &x,const pii &y){return x>y;}
signed main() {
  int t = rd();
  while(t--) {
    int n = rd(); int k = rd();
    len = 0; for(int i=1;i<=n;i++) fir[i] = -1;
    for(int i=1;i<n;i++){
      int x = rd(); int y = rd(); int d = rd();
      ins(x,y,d); ins(y,x,d);
    }if(k==0){puts("0"); continue;}

    for(int i=1;i<=n;i++) f[i] = g[i] = dp[i] = 0;

    function<void(int,int)>dfs=[&](int x,int fa) {
      vector<pii> q;
      for(int k=fir[x];k!=-1;k=edge[k].nex) {
        int y = edge[k].y;
        if(y==fa) continue;
        dfs(y,x);
      }q.CL();
      for(int k=fir[x];k!=-1;k=edge[k].nex) {
        int y = edge[k].y;
        if(y==fa) continue;
        q.PB(MP(f[y]+edge[k].d,y)); dp[x] += f[y]+edge[k].d;
      }sort(q.begin(),q.end(),cmp);
      if(x==1) for(int i=0;i<min((int)q.size(),k);i++) f[x] += q[i].fi;
      else for(int i=0;i<min((int)q.size(),k-1);i++) f[x] += q[i].fi;
    };
    dfs(1,0);
    
    int ans = 0;
    function<void(int,int,int)>dfs2=[&](int x,int fa,int d) {
      vector<pii> q; q.CL();
      ans = max(ans , dp[x] + g[fa] + d);
      q.PB(MP(g[fa] + d , fa));
      for(int k=fir[x];k!=-1;k=edge[k].nex) {
        int y = edge[k].y;
        if(y==fa) continue;
        q.PB(MP(f[y] + edge[k].d , y));
      }
      sort(q.begin(),q.end(),cmp); int s = 0; int qk=0; if((int)q.size() >= k) qk = q[k-1].fi;
      for(int i=0;i<min(k-1,(int)q.size());i++){
        bo[q[i].se] = 1;
        s += q[i].fi;
      }
      for(int k=fir[x];k!=-1;k=edge[k].nex) {
        int y = edge[k].y;
        if(y == fa) continue;
        if(bo[y]){g[x] = s - f[y] - edge[k].d + qk; dfs2(y,x,edge[k].d);}
        else{g[x] = s; dfs2(y,x,edge[k].d);}
      }
      for(int i=0;i<min(k-1,(int)q.size());i++) {
        bo[q[i].se] = 0;
      }
    };
    dfs2(1,0,0);
    printf("%lld\n",ans);
  }
  return 0;
}

你可能感兴趣的:(dp,hdu)