给你一棵节点数为 N N N的无根树,每条边有权值,要求选出一棵联通的图,图里面至多允许有一个度数大于 K K K的点,而且要使得边权和尽量大
N ≤ 1 0 5 , 0 ≤ k < n N\leq 10^5,0 \leq k
我们可以想用树形dp,维护每走到一个点,当前点的度数就可以大于K的最大答案。
那么就是当前点可以连边到所有的点,但是与其连边的子树内都不能有度数大于K的点。
于是可以用 d p [ i ] dp[i] dp[i]表示 i i i向所有自己子树连边的最大价值
f [ i ] f[i] f[i]表示 i i i只能向子树内连 K − 1 K-1 K−1条边(如果i是根节点,则可以连 K K K条边)的最大价值
所以就有:
d p [ x ] = ∑ f [ y ] + d i s ( x , y ) dp[x] = \sum\limits f[y] + dis(x,y) dp[x]=∑f[y]+dis(x,y)
f [ x ] = max i = 1 k f [ y ] + d i s ( x , y ) f[x] = \max\limits_{i=1}^{k} f[y] + dis(x,y) f[x]=i=1maxkf[y]+dis(x,y)
然后向其他点转移的时候,同样道理统计子树外的情况就好了,这里涉及到一些细节,可以想一想。
#include
#define int long long
#define MP make_pair
#define PB push_back
#define CL clear
#define fi first
#define se second
using namespace std;
typedef pair<int,int> pii;
const int N = 2e5+10;
inline int rd() {
int p=0; int f=1; char ch=getchar();
while(ch<'0' || ch>'9'){if(ch=='-') f=-1; ch=getchar();}
while(ch>='0' && ch<='9'){p=p*10+ch-'0'; ch=getchar();}
return p*f;
}
struct node{int x,y,nex,d;}edge[N<<1]; int len,fir[N];
void ins(int x,int y,int d){len++; edge[len].x=x; edge[len].y=y; edge[len].d=d; edge[len].nex = fir[x]; fir[x] = len;}
int f[N],g[N],dp[N]; bool bo[N];
bool cmp(const pii &x,const pii &y){return x>y;}
signed main() {
int t = rd();
while(t--) {
int n = rd(); int k = rd();
len = 0; for(int i=1;i<=n;i++) fir[i] = -1;
for(int i=1;i<n;i++){
int x = rd(); int y = rd(); int d = rd();
ins(x,y,d); ins(y,x,d);
}if(k==0){puts("0"); continue;}
for(int i=1;i<=n;i++) f[i] = g[i] = dp[i] = 0;
function<void(int,int)>dfs=[&](int x,int fa) {
vector<pii> q;
for(int k=fir[x];k!=-1;k=edge[k].nex) {
int y = edge[k].y;
if(y==fa) continue;
dfs(y,x);
}q.CL();
for(int k=fir[x];k!=-1;k=edge[k].nex) {
int y = edge[k].y;
if(y==fa) continue;
q.PB(MP(f[y]+edge[k].d,y)); dp[x] += f[y]+edge[k].d;
}sort(q.begin(),q.end(),cmp);
if(x==1) for(int i=0;i<min((int)q.size(),k);i++) f[x] += q[i].fi;
else for(int i=0;i<min((int)q.size(),k-1);i++) f[x] += q[i].fi;
};
dfs(1,0);
int ans = 0;
function<void(int,int,int)>dfs2=[&](int x,int fa,int d) {
vector<pii> q; q.CL();
ans = max(ans , dp[x] + g[fa] + d);
q.PB(MP(g[fa] + d , fa));
for(int k=fir[x];k!=-1;k=edge[k].nex) {
int y = edge[k].y;
if(y==fa) continue;
q.PB(MP(f[y] + edge[k].d , y));
}
sort(q.begin(),q.end(),cmp); int s = 0; int qk=0; if((int)q.size() >= k) qk = q[k-1].fi;
for(int i=0;i<min(k-1,(int)q.size());i++){
bo[q[i].se] = 1;
s += q[i].fi;
}
for(int k=fir[x];k!=-1;k=edge[k].nex) {
int y = edge[k].y;
if(y == fa) continue;
if(bo[y]){g[x] = s - f[y] - edge[k].d + qk; dfs2(y,x,edge[k].d);}
else{g[x] = s; dfs2(y,x,edge[k].d);}
}
for(int i=0;i<min(k-1,(int)q.size());i++) {
bo[q[i].se] = 0;
}
};
dfs2(1,0,0);
printf("%lld\n",ans);
}
return 0;
}