传送门:BZOJ 1036 [ZJOI2008] 树的统计 Count(原文链接缺失)
这是一道树链剖分的基础题:用线段树维护按 DFS 序排列的点权(同一条重链上的点在 DFS 序中是连续的),查询路径时沿轻边一条一条往上跳到下一条重链即可。查询时先求出两点的 LCA,再分别查询两个端点到 LCA 这两段路径的信息并合并(取最大值或相加);求和时 LCA 本身在两段中各被算了一次,所以要再减去一次它的点权。
代码:
/************************************************************** Problem: 1036 User: geng4512 Language: C++ Result: Accepted Time:3064 ms Memory:5188 kb ****************************************************************/
#include<cstdio>
#define MAXN 30005
// Static edge pool for the adjacency lists. Every undirected tree edge is
// stored twice (once per direction), hence MAXN<<1 slots. M points at the
// most recently used slot; because Add pre-increments M before writing,
// Edge[0] is never filled.
struct node {
    int v;      // vertex this edge leads to
    node *nxt;  // next edge in the same vertex's list (null ends the list)
} Edge[MAXN<<1], *Adj[MAXN], *M = Edge;

// Push the directed edge u -> v onto the front of u's adjacency list.
inline void Add(int u, int v) {
    node *e = ++M;
    e->nxt = Adj[u];
    e->v = v;
    Adj[u] = e;
}
// Return the larger of two ints.
inline int Max(int a, int b) {
    if (a < b) return b;
    return a;
}
// Segment-tree node covering DFS positions [l, r]. The tree is heap-indexed
// (children of node i are 2i and 2i+1), so t reserves 4*MAXN slots.
struct Seg {
    int l, r;   // covered interval of DFS positions
    int mx;     // maximum weight inside [l, r]
    int sum;    // total weight inside [l, r]
} t[MAXN<<2];
// Read the next (optionally negative) decimal integer from stdin into n.
// Non-digit characters before the number are skipped, and any '-' seen while
// skipping makes the result negative. The character right after the number
// is consumed. If input ends before a digit is found, n is set to 0 instead
// of looping forever.
inline void GET(int &n) {
    int ch = getchar();   // int, not char: EOF must stay distinguishable
    int sign = 1;
    n = 0;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        n = n * 10 + (ch - '0');
        ch = getchar();
    }
    n *= sign;
}
int n;            // number of vertices
int q;            // number of operations to process
int dcnt;         // running DFS-order counter advanced by dfs2
int fa[MAXN];     // parent of each vertex (the root's parent stays 0)
int dep[MAXN];    // depth of each vertex (the root stays 0)
int sz[MAXN];     // subtree size
int hsn[MAXN];    // heavy child (0 when the vertex has no children)
int htp[MAXN];    // top vertex of the heavy chain containing the vertex
int pos[MAXN];    // DFS position of a vertex (heavy child visited first)
int rev[MAXN];    // inverse of pos: vertex stored at a given DFS position
int px[MAXN];     // segment-tree leaf index for a given DFS position
int a[MAXN];      // current weight of each vertex
/************ 线段树相关 ******************/
void Build(int i, int l, int r) {
t[i].l = l; t[i].r = r;
if(l == r) {
px[l] = i; t[i].sum = t[i].mx = a[rev[l]]; return;
}
int mid = (l+r)>>1;
Build(i<<1, l, mid); Build(i<<1|1, mid+1, r);
t[i].mx = Max(t[i<<1].mx, t[i<<1|1].mx);
t[i].sum = t[i<<1].sum + t[i<<1|1].sum;
}
void Ins(int x, int v) {
int p = px[x]; t[p].mx = t[p].sum = v; p >>= 1;
while(p) {
t[p].mx = Max(t[p<<1].mx, t[p<<1|1].mx);
t[p].sum = t[p<<1].sum + t[p<<1|1].sum;
p >>= 1;
}
}
int Qmx(int i, int L, int R) {
if(t[i].r < L || t[i].l > R) return -0x3f3f3f3f;
if(L <= t[i].l && t[i].r <= R) return t[i].mx;
return Max(Qmx(i<<1, L, R), Qmx(i<<1|1, L, R));
}
int Qsm(int i, int L, int R) {
if(t[i].r < L || t[i].l > R) return 0;
if(L <= t[i].l && t[i].r <= R) return t[i].sum;
return Qsm(i<<1, L, R) + Qsm(i<<1|1, L, R);
}
/****************************************/
// Exchange two ints in place. The temporary is local, so this helper
// touches no file-scope state.
inline void Swap(int &a, int &b) {
    const int held = a;
    a = b;
    b = held;
}
// Lowest common ancestor via the heavy-chain decomposition. While a and b
// sit on different chains, the vertex whose chain top is deeper jumps to the
// parent of that chain top. Once both are on one chain, the shallower of the
// two is the LCA.
inline int LCA(int a, int b) {
    for (; htp[a] != htp[b]; a = fa[htp[a]]) {
        if (dep[htp[b]] > dep[htp[a]]) Swap(a, b);
    }
    return dep[b] > dep[a] ? a : b;
}
// Total weight on the path from a up to b, both endpoints included.
// b must be an ancestor of a (main always passes the LCA).
inline int sum(int a, int b) {
    int total = 0;
    // Query each heavy-chain piece top..a, then hop over the light edge above it.
    for (; htp[a] != htp[b]; a = fa[htp[a]])
        total += Qsm(1, pos[htp[a]], pos[a]);
    // a and b now share a chain; b comes first in DFS order.
    total += Qsm(1, pos[b], pos[a]);
    return total;
}
// Maximum weight on the path from a up to b, both endpoints included.
// b must be an ancestor of a (main always passes the LCA).
inline int Mx(int a, int b) {
    int best = -0x3f3f3f3f;
    // Query each heavy-chain piece top..a, then hop over the light edge above it.
    for (; htp[a] != htp[b]; a = fa[htp[a]])
        best = Max(best, Qmx(1, pos[htp[a]], pos[a]));
    // a and b now share a chain; b comes first in DFS order.
    return Max(Qmx(1, pos[b], pos[a]), best);
}
/*************** 轻重链剖分 ***************/
void dfs1(int u) {
sz[u] = 1;
for(node *p = Adj[u]; p; p = p->nxt) {
if(sz[p->v]) continue;
fa[p->v] = u; dep[p->v] = dep[u]+1;
dfs1(p->v);
if(sz[hsn[u]] < sz[p->v]) hsn[u] = p->v;
sz[u] += sz[p->v];
}
}
void dfs2(int u, int tp) {
pos[u] = ++ dcnt; htp[u] = tp; rev[dcnt] = u; //pos记录了u的dfs序。
if(!hsn[u]) return;
dfs2(hsn[u], tp);
for(node *p = Adj[u]; p; p = p->nxt)
if(htp[p->v]) continue;
else if(p->v != hsn[u]) dfs2(p->v, p->v);
}
/******************************************/
// Reads n, the n-1 tree edges and the n vertex weights, then builds the
// decomposition (rooted at vertex 1) and the segment tree. Each of the q
// operations is a keyword followed by u v:
//   keyword starting with 'C'          -> set the weight of vertex u to v
//   otherwise, second letter is 'M'    -> print the path maximum u..v
//   otherwise                          -> print the path sum u..v
// (presumably CHANGE / QMAX / QSUM, as in BZOJ 1036).
int main() {
    int u, v;
    GET(n);
    for (int i = 1; i < n; ++i) {
        GET(u); GET(v);
        Add(u, v); Add(v, u);
    }
    for (int i = 1; i <= n; ++i) GET(a[i]);
    dfs1(1);
    dfs2(1, 1);
    Build(1, 1, n);
    GET(q);
    // A 6-letter keyword such as CHANGE needs 7 bytes with its terminator.
    // Size the buffer generously and let %15s bound the write.
    char op[16];
    for (int i = 1; i <= q; ++i) {
        if (scanf("%15s", op) != 1) break;   // input ended early
        GET(u); GET(v);
        if (op[0] == 'C') {
            a[u] = v;   // keep a[] current: the sum query subtracts a[lca]
            Ins(pos[u], v);
        } else {
            const int lca = LCA(u, v);
            if (op[1] == 'M')
                printf("%d\n", Max(Mx(u, lca), Mx(v, lca)));
            else   // lca lies on both half-paths, so drop one copy of it
                printf("%d\n", sum(u, lca) + sum(v, lca) - a[lca]);
        }
    }
    return 0;
}