(赛前练手#1)51nod1766 树上的最远点对(线段树 + LCA)

题目链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1766

我们需要知道一个很重要的结论:对于两个点集合S和V,S U V的最远点对一定是S的最远点对或V的最远点对或S的最远点对与V的最远点对的匹配(也就是C(4 2) == 6),那么为了把这种集合给标准化、格式化,就像二分一样,明明我们不一定要去平分区间,但为了充分考虑各种情况我们选择取中点。
注意:对于各个子区间,我们需要讨论六种情况,而对于最后的那个两个大区间的合并,根据题意

请你求出两个区间内各选一点之间的最大距离,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}

所以这里我们要减去直接取区间内部值的两种情况

于是这里我们使用线段树来维护(非树剖,仅仅直接对点的编号建立线段树)
话说我还是用倍增求LCA也卡过去了???下面的是采用的倍增求LCA做法,可采用ST表求LCA,将log查询降到O1
AC Code:

#include
#define rg register
#define il inline
#define maxn 200005
#define ll long long
#define lid id << 1
#define rid (id << 1) | 1
#define rep(a,b,c) for (rg int a = 1 ; a <= c ; a += b)
using namespace std;
il int read(){rg int x = 0 , w = 1;rg char ch = getchar();while (ch < '0' || ch > '9'){if (ch == '-') w = -1;ch = getchar();}while (ch >= '0' && ch <= '9'){x = (x << 3) + (x << 1) + ch - '0';ch = getchar();}return x * w;}
int head[maxn] ,  cnt , dep[maxn] , f[maxn][23];
ll len[maxn]; 
struct edge{
    int fr , to , next;
	ll v;	
}e[maxn << 1];
struct tree{
	int l , r , p1 , p2;
	ll max;
}t[maxn << 2];
struct node{
	int p1 , p2;
	ll max;	
}len0[10];
bool cmp(node a,node b){
	return a.max > b.max;	
}
void add(int u,int v,ll w){
    e[++cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
    e[cnt].v = w;
    e[cnt].fr = u;	
}
void dfs(int u,int pre,int depth){
    //q[++tot] = u;
    dep[u] = depth;
    f[u][0] = pre;
    for (rg int i = 1 ; i <= 18 ; ++i)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for (rg int i = head[u] ; i ; i = e[i].next){
        rg int to = e[i].to;
        if (to != pre){
        	len[to] = len[u] + e[i].v;
            dfs(to , u , depth + 1);
        }
    }
}
int lca(int a , int b){
    if (dep[a] > dep[b]) swap(a , b);
    for (rg int i = 18 ; i >= 0 ; --i)
        if (dep[b] - dep[a] >= (1 << i)) b = f[b][i];
    if (a == b) return a;
    for (rg int i = 18 ; i >= 0 ; --i)
        if (f[a][i] != f[b][i])
            a = f[a][i] , b = f[b][i];
    return f[a][0];
}
void pushup(int id){
	len0[1].max = len[t[lid].p1] + len[t[rid].p1] - 2 * len[lca(t[lid].p1 , t[rid].p1)] , len0[1].p1 = t[lid].p1 , len0[1].p2 = t[rid].p1;
	len0[2].max = len[t[lid].p1] + len[t[rid].p2] - 2 * len[lca(t[lid].p1 , t[rid].p2)] , len0[2].p1 = t[lid].p1 , len0[2].p2 = t[rid].p2;
	len0[3].max = len[t[lid].p2] + len[t[rid].p1] - 2 * len[lca(t[lid].p2 , t[rid].p1)] , len0[3].p1 = t[lid].p2 , len0[3].p2 = t[rid].p1;
	len0[4].max = len[t[lid].p2] + len[t[rid].p2] - 2 * len[lca(t[lid].p2 , t[rid].p2)] , len0[4].p1 = t[lid].p2 , len0[4].p2 = t[rid].p2;
	len0[5].max = t[lid].max , len0[5].p1 = t[lid].p1 , len0[5].p2 = t[lid].p2;
	len0[6].max = t[rid].max , len0[6].p1 = t[rid].p1  , len0[6].p2 = t[rid].p2;
	rg int k = 1;
	for (rg int i = 2 ; i <= 6 ; ++i)
		if (len0[i].max > len0[k].max)
			k = i;
	t[id].p1 = len0[k].p1 , t[id].p2 = len0[k].p2 , t[id].max = len0[k].max;
}
void build(int id,int l,int r){
	t[id].l = l , t[id].r = r;
	if (l == r) {t[id].p1 = t[id].p2 = r , t[id].max = 0;return;}
	rg int mid = (l + r) >> 1;
	build(lid , l , mid);
	build(rid , mid + 1 , r);
	pushup(id);
}
node query(int id,int l,int r){
	if (t[id].l == l && t[id].r == r){
		return (node){t[id].p1 , t[id].p2 , t[id].max};
	}
	rg int mid = (t[id].l + t[id].r) >> 1;
	if (r <= mid) return query(lid , l , r);
	else if (l > mid) return query(rid , l , r);
	else{
		node a = query(lid , l , mid) , b = query(rid , mid + 1 , r);
		if (!a.max && !b.max) return (node){a.p1 , b.p1 , len[a.p1] + len[b.p1] - 2 * len[lca(a.p1 , b.p1)]};
		len0[1].max = len[a.p1] + len[b.p1] - 2 * len[lca(a.p1 , b.p1)] , len0[1].p1 = a.p1 , len0[1].p2 = b.p1;
		len0[2].max = len[a.p1] + len[b.p2] - 2 * len[lca(a.p1 , b.p2)] , len0[2].p1 = a.p1 , len0[2].p2 = b.p2;
		len0[3].max = len[a.p2] + len[b.p1] - 2 * len[lca(a.p2 , b.p1)] , len0[3].p1 = a.p2 , len0[3].p2 = b.p1;
		len0[4].max = len[a.p2] + len[b.p2] - 2 * len[lca(a.p2 , b.p2)] , len0[4].p1 = a.p2 , len0[4].p2 = b.p2;
		len0[5] = a , len0[6] = b;
		rg int k = 1;
		for (rg int i = 2 ; i <= 6 ; ++i)
			if (len0[i].max > len0[k].max)
				k = i; 
		return len0[k];
	}
}
int main(){
    rg int n = read() , u , v;
	ll w;
    for (rg int i = 1 ; i < n ; ++i){
    	u = read() , v = read();
		scanf("%lld" , &w);
    	add(u , v , w);
    	add(v , u , w);
    }
    dfs(1 , 0 , 1);
    rg int q = read();
    build(1 , 1 , n);
    int l1 , r1 , l2 , r2;
    for (rg int i = 1 ; i <= q ; ++i){
    	l1 = read() , r1 = read() , l2 = read() , r2 = read();
    	node a = query(1 , l1 , r1);
		node b = query(1 , l2 , r2);
		len0[1].max = len[a.p1] + len[b.p1] - 2 * len[lca(a.p1 , b.p1)]; //, len0[1].p1 = a.p1 , len0[1].p2 = b.p1;
		len0[2].max = len[a.p1] + len[b.p2] - 2 * len[lca(a.p1 , b.p2)]; //, len0[2].p1 = a.p1 , len0[2].p2 = b.p2;
		len0[3].max = len[a.p2] + len[b.p1] - 2 * len[lca(a.p2 , b.p1)]; //, len0[3].p1 = a.p2 , len0[3].p2 = b.p1;
		len0[4].max = len[a.p2] + len[b.p2] - 2 * len[lca(a.p2 , b.p2)]; //, len0[4].p1 = a.p2 , len0[4].p2 = b.p2;
		rg int k = 1;
		for (rg int i = 2 ; i <= 4 ; ++i)
			if (len0[i].max > len0[k].max)
				k = i; 
		printf("%lld\n" , len0[k].max);	
    }
    return 0;
}```

你可能感兴趣的:(线段树,LCA,赛前练手)