比赛的时候想直接了一个 dp 的方法,令 :
dp[i][j] 表示 i 的子树内选 j 条边取 a 值的最小直径
tp[i][j] 表示 i 的子树内选 j 条边取 a 值时,距离 i 最远的叶子的最小距离
转移方程不难推,利用 tp[i][j]
和 dp[v][j]
可以转移得到 dp[i][j]
仔细一想其实这个状态不够严谨,会产生后效性,因为 tp[i][j]
取最小时 dp[i][j]
不一定取最小,如此通过子节点的 tp 和 dp
来转移。
令最大距离最小,很容易想到二分答案,二分答案 x x x,对 t p [ i ] [ j ] tp[i][j] tp[i][j] 的定义稍作修改:令 d p [ i ] [ j ] dp[i][j] dp[i][j] 表示 i i i 的子树内,选 j 条边取 a 值,且子树内不会有距离超过 x x x 的直径时,距离 i 点最远的节点的最小距离。
转移时如果能构成长度大于 x x x 的直径,则不转移。
但这题时间卡得特别紧,如果不作优化,复杂度为 O ( n k 2 log a n s ) O(nk^2\log ans) O(nk2logans),转移时取 m i n ( s i z [ u ] , k ) min(siz[u],k) min(siz[u],k),可以优化到 n k log a n s nk\log ans nklogans,使用刷表常数会比填表小,卡一卡常就过了。
总结:对这类题难点是 d p dp dp 状态的定义,定义不够严谨会导致答案错误,还会浪费大量的时间去编码和调试,而比较严谨的定义需要大胆的尝试,尤其是这种带有限制条件的状态。
代码:
#include
using namespace std;
const int maxn = 2e4 + 10;
typedef long long ll;
int n, t, k;
int head[maxn], to[maxn << 1], nxt[maxn << 1], cnt, son[maxn];
ll a[maxn << 1], b[maxn << 1];
template<typename elemType>
inline void Read(elemType &T){
elemType X=0,w=0; char ch=0;
while(!isdigit(ch)) {w|=ch=='-';ch=getchar();}
while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
T=(w?-X:X);
}
void init() {
cnt = 0;
for (int i = 0; i <= n; i++)
head[i] = -1;
}
void add(int u,int v,ll ai,ll bi) {
to[cnt] = v;
a[cnt] = ai;
b[cnt] = bi;
nxt[cnt] = head[u];
head[u] = cnt++;
}
ll dp[22][maxn], tp[22];
void dfs(int u,int fa,ll x) {
son[u] = 0;
for (int i = 0; i <= k; i++)
dp[i][u] = 0;
for (int cur = head[u]; cur + 1; cur = nxt[cur]) {
int v = to[cur];
if (v == fa) continue;
dfs(v,u,x);
int num = min(son[u] + son[v] + 1,k);
for (int i = 0; i <= num; i++)
tp[i] = dp[i][u], dp[i][u] = x + 1;
for (int i = 0; i <= son[u]; i++) {
for (int j = 0; j <= son[v] && i + j <= k; j++) {
if (tp[i] + dp[j][v] + b[cur] <= x)
dp[i + j][u] = min(dp[i + j][u],max(dp[j][v] + b[cur],tp[i]));
if (tp[i] + dp[j][v] + a[cur] <= x) {
dp[i + j + 1][u] = min(dp[i + j + 1][u],max(dp[j][v] + a[cur],tp[i]));
}
}
}
son[u] = num;
}
}
bool check(ll x) {
dfs(1,0,x);
return dp[k][1] <= x;
}
ll solve(ll l,ll r) {
while (l < r) {
ll mid = l + r >> 1ll;
if (check(mid)) r = mid;
else l = mid + 1;
}
return l;
}
int main() {
Read(t);
while (t--) {
Read(n); Read(k);
init();
ll sum = 0;
for (int i = 1; i < n; i++) {
int x, y; ll a, b;
Read(x); Read(y); Read(a); Read(b);
add(x,y,a,b);
add(y,x,a,b);
sum += max(a,b);
}
printf("%lld\n",solve(1,sum));
}
return 0;
}