N N N个点 N − 1 N-1 N−1条边的树,每条边有两个可能值 a i , b i a_i,b_i ai,bi,已知有 K K K条边的权值为其对应 a i a_i ai,其余的 N − K + 1 N-K+1 N−K+1条边的权值为对应的 b i b_i bi,现在确定 K K K条边使得树的直径最小时的答案
树的直径为树上最长的路径,本题意在最小化最大值,通常这种问题都可以用二分答案解决
现在考虑如何快速判断能否在确定答案 a n s ans ans时,恰好选 K K K条边使得当前树的直径小于 ≤ a n s \leq ans ≤ans
对于一个确定答案 a n s ans ans,本是要求整棵树的最长路径 ≤ a n s \leq ans ≤ans
换句话说,树上任意一条路径都 ≤ a n s \leq ans ≤ans,这样看似增加了求解难度
但是实际上只需要满足对于每一个点,其子树中经过他的最长路径最短即可,这样就不需要求树的直径
原因可以参考点分治中的①②类路径或者树的直径的 d p dp dp求法
考虑如何求一个点的子树中,经过该点的最长路径
记 f [ u ] [ k ] f[u][k] f[u][k]为 u u u的子树中选择 k k k条边的最长路径的最小值(这里比较绕),本次搜索的子树为 v v v
接下来就是一个树上的背包问题(类似可以参考)
但是它还带上了一个约束条件,只有满足当前最长路径长度 ≤ a n s \leq ans ≤ans才能更新 f f f
dfs(v, u, lim);
int now = min(siz[u] + siz[v] + 1, K);
ll t[22]; for (int j = 0; j <= now; j++) t[j] = lim + 1;
//因为现在的f存的是之前搜过的子树, 如果直接更新, 下面的if判断中就不是之前搜过的最长
//会有同一子树中两条最长路径的情况出现,所以用t数组暂时保存更新后的答案
for (int j = 0; j <= siz[u]; j++)
for (int k = 0; k <= siz[v] && j + k <= K; k++) {
//只有之前搜的子树中最长路径 + 当前子树中最长路径 <= ans才能更新
if (f[u][j] + f[v][k] + e[i].a <= lim)
t[j + k + 1] = min(t[j + k + 1], max(f[u][j], f[v][k] + e[i].a));//保证路径是最长之后使其最小
if (f[u][j] + f[v][k] + e[i].b <= lim)
t[j + k] = min(t[j + k], max(f[u][j], f[v][k] + e[i].b));
}
for (int j = 0; j <= now; j++) f[u][j] = t[j];//全部更新完才会更新f
siz[u] = now;
#include
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int MAX = 2e4 + 10;
struct edge {
int nxt, to;
ll a, b;
} e[MAX << 1];
int head[MAX], tot;
void add(int u, int v, ll a, ll b) { e[++tot] = edge{head[u], v, a, b}; head[u] = tot; }
int N, K;
int siz[MAX];
ll f[MAX][22];
void dfs(int u, int fa, ll lim) {
siz[u] = 0;
for (int i = 0; i <= K; i++) f[u][i] = 0;
for (int i = head[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa) {
dfs(v, u, lim);
int now = min(siz[u] + siz[v] + 1, K);
ll t[22]; for (int j = 0; j <= now; j++) t[j] = lim + 1;
for (int j = 0; j <= siz[u]; j++)
for (int k = 0; k <= siz[v] && j + k <= K; k++) {
if (f[u][j] + f[v][k] + e[i].a <= lim)
t[j + k + 1] = min(t[j + k + 1], max(f[u][j], f[v][k] + e[i].a));
if (f[u][j] + f[v][k] + e[i].b <= lim)
t[j + k] = min(t[j + k], max(f[u][j], f[v][k] + e[i].b));
}
for (int j = 0; j <= now; j++) f[u][j] = t[j];
siz[u] = now;
}
}
int main() {
int T; scanf("%d", &T);
while (T--) {
scanf("%d%d", &N, &K);
ll l = 1, r = 0;
for (int i = 1; i < N; i++) {
int u, v; ll a, b; scanf("%d%d%lld%lld", &u, &v, &a, &b);
add(u, v, a, b); add(v, u, a, b);
r += max(a, b);
}
ll ans = r;
while (l <= r) {
ll mid = (l + r) / 2;
dfs(1, 0, mid);
if (f[1][K] <= mid) r = mid - 1, ans = mid;
else l = mid + 1;
}
printf("%lld\n", ans);
for (int i = 1; i <= tot; i++) head[i] = 0; tot = 0;
}
return 0;
}