单看一个询问,就是求仅考虑区间内的点的直径。
树的直径具有可合并的性质,同一棵树上两个区间的直径的两个端点分别为 ( a , b ) (a,b) (a,b), ( c , d ) (c,d) (c,d),那么合并两个区间后的新的直径的端点一定在 a , b , c , d {a,b,c,d} a,b,c,d 中,通过枚举端点计算它们的距离,取最大值可以得到两个区间合并的直径。 其正确性证明和两遍 d f s dfs dfs 求树的直径的证明过程类似。
对于这题,考虑用 s t st st 表预处理区间的直径,对于询问用类似的合并方法合并两个 st 表即可。
查询 l c a lca lca 要用欧拉序和 s t st st 表,不然会TLE。
注意:内置的 log 2 ( ) \log2() log2() 函数常数非常大,改成预处理每个数的 log 2 \log_2 log2 值可以降低常数。
时空复杂度均为 O ( n log 2 n ) O(n\log_2n) O(nlog2n)
代码:
#include
using namespace std;
const int maxn = 3e5 + 10;
#define pii pair
#define fir first
#define sec second
typedef long long ll;
int n, q;
int fir[maxn], st[maxn << 1][25], cnt, bin[maxn], lg[2 * maxn];
pii pot[maxn][25];
ll dep[maxn];
struct node {
int head[maxn], nxt[maxn << 1], cnt, to[maxn << 1], w[maxn << 1];
void init() {
memset(head,-1,sizeof head);
cnt = 0;
}
void add(int u,int v,int wi) {
to[cnt] = v;
w[cnt] = wi;
nxt[cnt] = head[u];
head[u] = cnt++;
}
} g;
void prework(int u,int fa) {
fir[u] = ++cnt; st[cnt][0] = u;
pot[u][0] = pii(u,u);
for (int i = g.head[u]; i + 1; i = g.nxt[i]) {
if (g.to[i] == fa) continue;
dep[g.to[i]] = dep[u] + g.w[i];
prework(g.to[i],u);
st[++cnt][0] = u;
}
}
int calc(int u,int v) {
return dep[u] < dep[v] ? u : v;
}
int getlca(int u,int v) {
if (fir[u] > fir[v]) swap(u,v);
int p = lg[fir[v] - fir[u] + 1];
return calc(st[fir[u]][p],st[fir[v] - bin[p] + 1][p]);
}
ll getdis(ll u,ll v) {
int lca = getlca(u,v);
return dep[u] + dep[v] - dep[lca] - dep[lca];
}
void init() {
for (int i = 0; i <= 22; i++)
bin[i] = 1 << i;
lg[1] = 0;
for (int i = 2; i <= cnt; i++)
lg[i] = lg[i >> 1] + 1;
for (int i = 1; bin[i] <= cnt; i++)
for (int j = 1; j + bin[i] - 1 <= cnt; j++)
st[j][i] = calc(st[j][i - 1],st[j + bin[i - 1]][i - 1]);
for (int i = 1; bin[i] <= n; i++)
for (int j = 1; j + bin[i] - 1 <= n; j++) {
pii x = pot[j][i - 1], y = pot[j + bin[i - 1]][i - 1];
pot[j][i] = pii(x.fir,y.fir);
if (getdis(x.fir,y.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
pot[j][i] = pii(x.fir,y.sec);
if (getdis(x.sec,y.fir) > getdis(pot[j][i].fir,pot[j][i].sec))
pot[j][i] = pii(x.sec,y.fir);
if (getdis(x.sec,y.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
pot[j][i] = pii(x.sec,y.sec);
if (getdis(x.fir,x.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
pot[j][i] = pii(x.fir,x.sec);
if (getdis(y.fir,y.sec) > getdis(pot[j][i].fir,pot[j][i].sec))
pot[j][i] = pii(y.fir,y.sec);
}
}
ll query(int l,int r) {
int p = lg[r - l + 1];
pii x = pot[l][p], y = pot[r - bin[p] + 1][p];
ll ans = getdis(x.fir,y.fir);
ans = max(ans,getdis(x.fir,y.sec));
ans = max(ans,getdis(x.sec,y.fir));
ans = max(ans,getdis(x.sec,y.sec));
ans = max(ans,getdis(x.fir,x.sec));
ans = max(ans,getdis(y.fir,y.sec));
return ans;
}
int main() {
g.init();
scanf("%d%d",&n,&q);
for (int i = 1; i < n; i++) {
int u, v, w; scanf("%d%d%d",&u,&v,&w);
g.add(u,v,w);
g.add(v,u,w);
}
prework(1,0);
init();
while (q--) {
int l, r; scanf("%d%d",&l,&r);
printf("%lld\n",query(l,r));
}
return 0;
}