这题的数据范围变成了 200000，O(n^2) 的做法过不了；同时要求的是最少的边数，不能用容斥。
#include <algorithm>
#include <climits>
#include <cstdio>
#include <utility>

// Reads a weighted tree and prints the minimum number of edges on a simple
// path whose total weight is exactly k, or -1 when no such path exists.
//
// Input : "n k", then n - 1 lines "u v c" (0-based endpoints, edge weight c).
// Method: centroid decomposition. For every centroid, each path through it is
//         split into two root-to-vertex halves that lie in different child
//         subtrees; ok[d] stores the fewest edges of a half with weight d.
// Assumes non-negative edge weights (the dep > k pruning and the ok[] indexing
// both rely on it) and k below MAXK.

const int MAXN = 200000 + 5;   // vertex capacity
const int MAXM = 200000 + 5;   // undirected edge capacity
const int MAXK = 1000000 + 5;  // capacity of the per-distance table ok[]

// Forward-star adjacency lists. Edge ids start at 2 because ed starts at 1;
// id 0 is the list terminator.
int to[MAXM << 1], nxt[MAXM << 1], Head[MAXN], ed = 1;
int cost[MAXM << 1];

// ok[d]: fewest edges from the current centroid to an already merged vertex at
// weighted distance d, INT_MAX if none. ok[0] == 0 stands for the centroid.
int ok[MAXK];

// Appends the directed edge u -> v with weight c to u's adjacency list.
inline void addedge(int u, int v, int c) {
    to[++ed] = v;
    cost[ed] = c;
    nxt[ed] = Head[u];
    Head[u] = ed;
}

// Adds the undirected edge {u, v} with weight c as two directed edges.
inline void ADD(int u, int v, int c) {
    addedge(u, v, c);
    addedge(v, u, c);
}

int n, anser, k, cnt;
int sz[MAXN], f[MAXN], dep[MAXN], sumsz, root;
bool vis[MAXN];              // vis[x]: x was already used as a centroid
std::pair<int, int> o[MAXN]; // (weighted distance, edge count) of the current child subtree
int num[MAXN];               // num[0] = count; num[1..] = distances written into ok[]

// Computes subtree sizes of the component containing x (entered from fa,
// skipping removed vertices) and moves `root` to the vertex whose largest
// remaining part f[] is smallest. The caller sets `sumsz` to the component
// size and `root` to 0 beforehand; f[0] acts as the initial upper bound.
void getroot(int x, int fa) {
    sz[x] = 1;
    f[x] = 0;
    for (int i = Head[x]; i; i = nxt[i]) {
        const int v = to[i];
        if (v == fa || vis[v]) {
            continue;
        }
        getroot(v, x);
        sz[x] += sz[v];
        f[x] = std::max(f[x], sz[v]);
    }
    f[x] = std::max(f[x], sumsz - sz[x]);
    if (f[x] < f[root]) {
        root = x;
    }
}

// Collects (dep, edge count) for every vertex reachable from x without going
// through fa or a removed vertex. dep[x] must be set by the caller; `deep` is
// the number of edges from the centroid to x.
void getdeep(int x, int fa, int deep) {
    // A half-path heavier than k can never be completed to weight k, and
    // recording it would index ok[] out of range.
    if (dep[x] > k) {
        return;
    }
    o[++cnt] = std::make_pair(dep[x], deep);
    num[++num[0]] = dep[x];
    for (int i = Head[x]; i; i = nxt[i]) {
        const int v = to[i];
        if (v == fa || vis[v]) {
            continue;
        }
        dep[v] = dep[x] + cost[i];
        getdeep(v, x, deep + 1);
    }
}

// Updates `anser` with every weight-k path that passes through centroid x,
// where d is the distance assigned to x itself (solve() passes 0).
void calc(int x, int d) {
    num[0] = 0;
    dep[x] = d;
    for (int i = Head[x]; i; i = nxt[i]) {
        const int v = to[i];
        if (vis[v]) {
            continue;
        }
        cnt = 0;
        dep[v] = dep[x] + cost[i];
        getdeep(v, x, 1);

        // Query first, merge afterwards: this keeps the two halves in
        // different child subtrees. getdeep() only records distances <= k,
        // so k - o[j].first is a valid index.
        for (int j = 1; j <= cnt; j++) {
            const int other = ok[k - o[j].first];
            if (other != INT_MAX) {
                anser = std::min(anser, other + o[j].second);
            }
        }
        for (int j = 1; j <= cnt; j++) {
            ok[o[j].first] = std::min(ok[o[j].first], o[j].second);
        }
    }

    // Undo only the entries this centroid touched, so the reset costs
    // O(subtree size) rather than O(k).
    for (int i = 1; i <= num[0]; i++) {
        ok[num[i]] = INT_MAX;
    }
    // A zero-weight edge puts distance 0 in the touched list, and the loop
    // above would then erase the "centroid itself" entry for every later
    // centroid. Restore it unconditionally.
    ok[0] = 0;
}

// Processes centroid x, then recurses into the centroid of every component
// left after removing x.
void solve(int x) {
    vis[x] = true;
    calc(x, 0);
    const int totsz = sumsz;  // size of the component x was the centroid of
    for (int i = Head[x]; i; i = nxt[i]) {
        const int v = to[i];
        if (vis[v]) {
            continue;
        }
        root = 0;
        // sz[] comes from a getroot() run that did not start at x. If
        // sz[v] > sz[x], v was x's parent in that run, so v's component is
        // everything outside x's subtree.
        sumsz = sz[v] > sz[x] ? totsz - sz[x] : sz[v];
        getroot(v, 0);
        solve(root);
    }
}

int main() {
    if (scanf("%d %d", &n, &k) != 2) {
        // No header means no tree, hence no path.
        printf("-1\n");
        return 0;
    }

    anser = INT_MAX;
    // ok[0] stays 0 from static zero-initialisation.
    for (int i = 1; i <= k; i++) {
        ok[i] = INT_MAX;
    }

    for (int i = 1; i < n; i++) {
        int u, v, c;
        if (scanf("%d %d %d", &u, &v, &c) != 3) {
            break;  // truncated input: never add an edge built from garbage
        }
        ADD(u + 1, v + 1, c);  // vertices are shifted to 1-based; 0 means "no parent"
    }

    root = 0;
    sumsz = f[0] = n;  // f[0] = n is the sentinel getroot() improves upon
    getroot(1, 0);
    solve(root);

    printf("%d\n", anser == INT_MAX ? -1 : anser);
    return 0;
}
注意：dep[x] > k 的时候要 return，不然 ok[k - dep[x]] 的下标会越界，导致 RE（运行时错误）。