【题目链接】
大概并没有用正解,卡常数A了。
先二分答案mid,那么边权变为0的边一定在所有长度大于mid的路径的交上,且这条边的边权至少为这些路径中的最大长度减去mid(否则最长的那条路径的长度不可能降到mid及以下)。
每次check,用树上差分把长度大于mid的路径加进去(对每条路径,两个端点各+1,LCA处-2),再做一次dfs求子树和,就得到每条边被多少条这样的路径经过;被全部这些路径经过的边就构成它们的交。最后遍历每条边,看是否有边满足上述条件。
/* Pigonometry */
// Binary search on the answer `mid`: every path longer than mid must contain the
// edge whose weight is set to 0, and that edge must weigh at least
// (longest path - mid). Which edges lie on all long paths is counted with a
// difference array on the tree (+1 at both endpoints, -2 at their LCA; subtree
// sums then give, for each vertex, how many long paths use the edge to its parent).
#include <cstdio>
#include <algorithm>
#include <utility>

const int maxn = 300005, maxm = 300005, maxk = 20;

int n, m, head[maxn], cnt;

// Adjacency-list edge: target vertex, weight, index of the next edge leaving the
// same source vertex (-1 terminates the list).
struct Edge { int v, w, next; } g[maxn << 1];

// A query path: endpoints, their LCA and the weighted path length.
// NOTE(review): lengths are kept in int; assumes the longest path fits in int.
struct Path { int u, v, lca, dis; } paths[maxm];

// Reads one integer from stdin. A '-' is taken as the sign only when it is the
// last non-digit character right before the digits. Hitting EOF before any digit
// yields 0 instead of looping forever.
inline int iread() {
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    bool neg = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        neg = (ch == '-');
        ch = getchar();
    }
    int x = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + (ch - '0');
    return neg ? -x : x;
}

// Appends the directed edge u -> v with weight w to u's adjacency list.
inline void add(int u, int v, int w) {
    g[cnt] = Edge{v, w, head[u]};
    head[u] = cnt++;
}

/* doubling lca */
// pre[x][j]: 2^j-th ancestor of x (0 once past the root).
// val[x]:    weight of the edge from x to its parent (0 for the root).
// depth[x]:  edge count from the root; dis[x]: weighted distance from the root.
// order:     vertices listed so that every parent precedes its children.
int pre[maxn][maxk], val[maxn], depth[maxn], dis[maxn];
int order[maxn], order_len;

// Roots the tree at `root`, filling pre[][0], val, depth, dis and order.
// Uses an explicit stack: recursion on a long chain (~3e5 vertices) can overflow
// the call stack.
void dfs(int root) {
    static int stk[maxn];
    int top = 0;
    order_len = 0;
    stk[top++] = root;
    while (top > 0) {
        const int x = stk[--top];
        order[order_len++] = x;  // x's parent was popped (and recorded) earlier
        for (int i = head[x]; ~i; i = g[i].next) {
            const int y = g[i].v;
            if (y == pre[x][0]) continue;  // skip the edge back to the parent
            pre[y][0] = x;
            val[y] = g[i].w;
            depth[y] = depth[x] + 1;
            dis[y] = dis[x] + g[i].w;
            stk[top++] = y;
        }
    }
}

// Lowest common ancestor of u and v via binary lifting (pre[][] must be filled).
inline int getlca(int u, int v) {
    if (depth[u] < depth[v]) std::swap(u, v);
    // Lift u by exactly the depth difference. The old test
    // depth[pre[u][i]] >= depth[v] let u land on the sentinel 0 when v is the
    // root, because depth[0] == depth[root] == 0.
    for (int diff = depth[u] - depth[v], i = 0; diff > 0; diff >>= 1, ++i)
        if (diff & 1) u = pre[u][i];
    if (u == v) return u;
    for (int i = maxk - 1; i >= 0; i--)
        if (pre[u][i] != pre[v][i]) u = pre[u][i], v = pre[v][i];
    return pre[u][0];
}

int sum[maxn];

// Returns whether zeroing a single edge can make every path length <= x.
bool check(int x) {
    int tot = 0, w = 0;  // number of paths longer than x; their largest excess over x
    std::fill(sum + 1, sum + n + 1, 0);
    for (int i = 1; i <= m; i++)
        if (paths[i].dis > x) {
            tot++;
            sum[paths[i].u]++;
            sum[paths[i].v]++;
            sum[paths[i].lca] -= 2;
            w = std::max(w, paths[i].dis - x);
        }
    // Children appear after their parent in `order`, so sweeping it backwards turns
    // the difference array into subtree sums: afterwards sum[i] is the number of
    // long paths using the edge (i, parent of i). No recursion needed.
    for (int k = order_len - 1; k > 0; --k) sum[pre[order[k]][0]] += sum[order[k]];
    // An edge works if every long path uses it and it is heavy enough.
    for (int i = 1; i <= n; i++)
        if (sum[i] == tot && val[i] >= w) return true;
    return false;
}

int main() {
    n = iread();
    m = iread();
    std::fill(head + 1, head + n + 1, -1);
    cnt = 0;
    for (int i = 1; i < n; i++) {
        int u = iread(), v = iread(), w = iread();
        add(u, v, w);
        add(v, u, w);
    }
    dfs(1);
    for (int j = 1; j < maxk; j++)
        for (int i = 1; i <= n; i++) pre[i][j] = pre[pre[i][j - 1]][j - 1];

    // The answer lies in [0, longest path]; check(longest path) trivially holds.
    int l = 0, r = 0;
    for (int i = 1; i <= m; i++) {
        int u = iread(), v = iread(), lca = getlca(u, v);
        int d = dis[u] + dis[v] - 2 * dis[lca];
        paths[i] = Path{u, v, lca, d};
        r = std::max(r, d);
    }
    // Find the smallest mid for which check(mid) holds.
    while (l <= r) {
        int mid = l + (r - l) / 2;
        if (check(mid)) r = mid - 1;
        else l = mid + 1;
    }
    printf("%d\n", l);
    return 0;
}