题目链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1766
我们需要知道一个很重要的结论:对于两个点集合S和V,S U V的最远点对一定是S的最远点对或V的最远点对或S的最远点对与V的最远点对的匹配(也就是C(4 2) == 6),那么为了把这种集合给标准化、格式化,就像二分一样,明明我们不一定要去平分区间,但为了充分考虑各种情况我们选择取中点。
注意:对于各个子区间,我们需要讨论六种情况,而对于最后的那个两个大区间的合并,根据题意
请你求出两个区间内各选一点之间的最大距离,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}
所以这里我们要减去直接取区间内部值的两种情况
于是这里我们使用线段树来维护(非树剖,仅仅直接对点的编号建立线段树)
话说我还是用倍增求LCA也卡过去了???下面的是采用的倍增求LCA做法,可采用ST表求LCA,将log查询降到O1
AC Code:
#include
#define rg register
#define il inline
#define maxn 200005
#define ll long long
#define lid id << 1
#define rid (id << 1) | 1
#define rep(a,b,c) for (rg int a = 1 ; a <= c ; a += b)
using namespace std;
il int read(){rg int x = 0 , w = 1;rg char ch = getchar();while (ch < '0' || ch > '9'){if (ch == '-') w = -1;ch = getchar();}while (ch >= '0' && ch <= '9'){x = (x << 3) + (x << 1) + ch - '0';ch = getchar();}return x * w;}
int head[maxn] , cnt , dep[maxn] , f[maxn][23];
ll len[maxn];
struct edge{
int fr , to , next;
ll v;
}e[maxn << 1];
struct tree{
int l , r , p1 , p2;
ll max;
}t[maxn << 2];
struct node{
int p1 , p2;
ll max;
}len0[10];
bool cmp(node a,node b){
return a.max > b.max;
}
void add(int u,int v,ll w){
e[++cnt].to = v;
e[cnt].next = head[u];
head[u] = cnt;
e[cnt].v = w;
e[cnt].fr = u;
}
void dfs(int u,int pre,int depth){
//q[++tot] = u;
dep[u] = depth;
f[u][0] = pre;
for (rg int i = 1 ; i <= 18 ; ++i)
f[u][i] = f[f[u][i - 1]][i - 1];
for (rg int i = head[u] ; i ; i = e[i].next){
rg int to = e[i].to;
if (to != pre){
len[to] = len[u] + e[i].v;
dfs(to , u , depth + 1);
}
}
}
int lca(int a , int b){
if (dep[a] > dep[b]) swap(a , b);
for (rg int i = 18 ; i >= 0 ; --i)
if (dep[b] - dep[a] >= (1 << i)) b = f[b][i];
if (a == b) return a;
for (rg int i = 18 ; i >= 0 ; --i)
if (f[a][i] != f[b][i])
a = f[a][i] , b = f[b][i];
return f[a][0];
}
void pushup(int id){
len0[1].max = len[t[lid].p1] + len[t[rid].p1] - 2 * len[lca(t[lid].p1 , t[rid].p1)] , len0[1].p1 = t[lid].p1 , len0[1].p2 = t[rid].p1;
len0[2].max = len[t[lid].p1] + len[t[rid].p2] - 2 * len[lca(t[lid].p1 , t[rid].p2)] , len0[2].p1 = t[lid].p1 , len0[2].p2 = t[rid].p2;
len0[3].max = len[t[lid].p2] + len[t[rid].p1] - 2 * len[lca(t[lid].p2 , t[rid].p1)] , len0[3].p1 = t[lid].p2 , len0[3].p2 = t[rid].p1;
len0[4].max = len[t[lid].p2] + len[t[rid].p2] - 2 * len[lca(t[lid].p2 , t[rid].p2)] , len0[4].p1 = t[lid].p2 , len0[4].p2 = t[rid].p2;
len0[5].max = t[lid].max , len0[5].p1 = t[lid].p1 , len0[5].p2 = t[lid].p2;
len0[6].max = t[rid].max , len0[6].p1 = t[rid].p1 , len0[6].p2 = t[rid].p2;
rg int k = 1;
for (rg int i = 2 ; i <= 6 ; ++i)
if (len0[i].max > len0[k].max)
k = i;
t[id].p1 = len0[k].p1 , t[id].p2 = len0[k].p2 , t[id].max = len0[k].max;
}
void build(int id,int l,int r){
t[id].l = l , t[id].r = r;
if (l == r) {t[id].p1 = t[id].p2 = r , t[id].max = 0;return;}
rg int mid = (l + r) >> 1;
build(lid , l , mid);
build(rid , mid + 1 , r);
pushup(id);
}
node query(int id,int l,int r){
if (t[id].l == l && t[id].r == r){
return (node){t[id].p1 , t[id].p2 , t[id].max};
}
rg int mid = (t[id].l + t[id].r) >> 1;
if (r <= mid) return query(lid , l , r);
else if (l > mid) return query(rid , l , r);
else{
node a = query(lid , l , mid) , b = query(rid , mid + 1 , r);
if (!a.max && !b.max) return (node){a.p1 , b.p1 , len[a.p1] + len[b.p1] - 2 * len[lca(a.p1 , b.p1)]};
len0[1].max = len[a.p1] + len[b.p1] - 2 * len[lca(a.p1 , b.p1)] , len0[1].p1 = a.p1 , len0[1].p2 = b.p1;
len0[2].max = len[a.p1] + len[b.p2] - 2 * len[lca(a.p1 , b.p2)] , len0[2].p1 = a.p1 , len0[2].p2 = b.p2;
len0[3].max = len[a.p2] + len[b.p1] - 2 * len[lca(a.p2 , b.p1)] , len0[3].p1 = a.p2 , len0[3].p2 = b.p1;
len0[4].max = len[a.p2] + len[b.p2] - 2 * len[lca(a.p2 , b.p2)] , len0[4].p1 = a.p2 , len0[4].p2 = b.p2;
len0[5] = a , len0[6] = b;
rg int k = 1;
for (rg int i = 2 ; i <= 6 ; ++i)
if (len0[i].max > len0[k].max)
k = i;
return len0[k];
}
}
int main(){
rg int n = read() , u , v;
ll w;
for (rg int i = 1 ; i < n ; ++i){
u = read() , v = read();
scanf("%lld" , &w);
add(u , v , w);
add(v , u , w);
}
dfs(1 , 0 , 1);
rg int q = read();
build(1 , 1 , n);
int l1 , r1 , l2 , r2;
for (rg int i = 1 ; i <= q ; ++i){
l1 = read() , r1 = read() , l2 = read() , r2 = read();
node a = query(1 , l1 , r1);
node b = query(1 , l2 , r2);
len0[1].max = len[a.p1] + len[b.p1] - 2 * len[lca(a.p1 , b.p1)]; //, len0[1].p1 = a.p1 , len0[1].p2 = b.p1;
len0[2].max = len[a.p1] + len[b.p2] - 2 * len[lca(a.p1 , b.p2)]; //, len0[2].p1 = a.p1 , len0[2].p2 = b.p2;
len0[3].max = len[a.p2] + len[b.p1] - 2 * len[lca(a.p2 , b.p1)]; //, len0[3].p1 = a.p2 , len0[3].p2 = b.p1;
len0[4].max = len[a.p2] + len[b.p2] - 2 * len[lca(a.p2 , b.p2)]; //, len0[4].p1 = a.p2 , len0[4].p2 = b.p2;
rg int k = 1;
for (rg int i = 2 ; i <= 4 ; ++i)
if (len0[i].max > len0[k].max)
k = i;
printf("%lld\n" , len0[k].max);
}
return 0;
}```