P2486
很经典的题~
思路: 线段树染色+"熟练"剖分(某些出题人总是喜欢把序列上的题加个树链剖分搞到树上去)
先想一想序列上怎么做吧
线段树是个好东西
每个节点维护三个信息: ls: 左端点的颜色 rs: 右端点的颜色 cnt: [l, r] 中共有几个颜色段
合并?
fa.cnt = son1.cnt + son2.cnt - [son1.rs == son2.ls]
fa.ls = son1.ls , fa.rs = son2.rs
爹的左端点颜色就是左儿子的左端点颜色, 右端点亦然
如果左儿子与右儿子相接的颜色相同, 那么等于左儿子块数加右儿子块数-1(中间两个块会合成一个)
否则直接加就行啦
修改时要打标记 记录有没有被覆盖
回到树上问题时要特别注意的是询问
因为询问时有swap的操作, 将k记录x,y的顺序, 即相当于(x, y) 还是(y, x)
如果是(y, x), 最后还要反回来才能进行合并
在跳重链时, 总是将链接在它的左边, 最后将a左右儿子反一下再与b合并即可
#include
#include
#include
#define ll long long
using namespace std;
const int N = 105000*4;
int fa[N], id[N], siz[N];
int num, dep[N], son[N];
int w[N], wt[N], Top[N];
int h[N], ne[N], to[N];
int tot;
inline void add(int x,int y) {
ne[++tot] = h[x], h[x] = tot;
to[tot] = y;
}
void dfs1(int x,int f) {
fa[x] = f;
siz[x] = 1, dep[x] = dep[f] + 1;
for (int i = h[x]; i ;i = ne[i]) {
int y = to[i];
if (y == f) continue;
dfs1(y, x);
siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
void dfs2(int x,int topf) {
id[x] = ++num; wt[num] = w[x];
Top[x] = topf;
if (!son[x]) return;
dfs2(son[x], topf);
for (int i = h[x]; i; i = ne[i]) {
int y = to[i];
if (y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
int n, m;
int L[N], R[N], cnt[N], ls[N], rs[N];
int tag[N];
#define p1 p << 1
#define p2 p << 1 | 1
struct node{
int cnt, ls, rs;
};
void update(node &fa,node i,node j) {
fa.cnt = i.cnt + j.cnt - (i.rs == j.ls);
fa.ls = i.ls, fa.rs = j.rs;
}
void build(int l,int r,int p) {
L[p] = l, R[p] = r;
if (l == r) {
cnt[p] = 1, ls[p] = rs[p] = wt[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, p1);
build(mid+1, r, p2);
cnt[p] = cnt[p1] + cnt[p2] - (rs[p1] == ls[p2]);
ls[p] = ls[p1], rs[p] = rs[p2];
}
void spread(int p) {
if (tag[p]) {
cnt[p1] = cnt[p2] = 1;
tag[p1] = tag[p2] = tag[p];
ls[p1] = ls[p2] = rs[p1] = rs[p2] = tag[p];
tag[p] = 0;
}
}
void change(int l,int r,int p,int c) {
if (L[p] >= l && R[p] <= r) {
cnt[p] = 1, tag[p] = c;
ls[p] = rs[p] = c;
return;
}
spread(p);
if (R[p1] >= l) change(l, r, p1, c);
if (L[p2] <= r) change(l, r, p2, c);
cnt[p] = cnt[p1] + cnt[p2] - (rs[p1] == ls[p2]);
ls[p] = ls[p1], rs[p] = rs[p2];
}
node ask(int l,int r,int p) {
if (L[p] >= l && R[p] <= r) return (node){cnt[p], ls[p], rs[p]};
spread(p);
node i;
if (R[p1] < l) return ask(l, r, p2);
if (L[p2] > r) return ask(l, r, p1);
update(i, ask(l, r, p1), ask(l, r, p2));
return i;
}
void change_e(int x,int y,int c) {
while (Top[x] != Top[y]) {
if (dep[Top[x]] < dep[Top[y]]) swap(x, y);
change(id[Top[x]], id[x], 1, c);
x = fa[Top[x]];
}
if (dep[x] < dep[y]) swap(x, y);
change(id[y], id[x], 1, c);
}
int sum(int x,int y) {
node ans, a, b;
ans = a = b = (node){0,0,0};
int k = 1;
while (Top[x] != Top[y]) {
if (dep[Top[x]] < dep[Top[y]]) swap(x, y), swap(a, b), k ^= 1;
if (a.cnt == 0) a = ask(id[Top[x]], id[x], 1);
else update(a, ask(id[Top[x]], id[x], 1), a);
x = fa[Top[x]];
}
if (dep[x] < dep[y]) swap(x, y), swap(a, b);
if (a.cnt == 0) a = ask(id[y], id[x], 1);
else update(a, ask(id[y], id[x], 1), a);
if (b.cnt == 0) return a.cnt;
if (a.cnt == 0) return b.cnt;
if (!k) swap(a, b); // 将a, b恢复正常顺序
swap(a.ls, a.rs); //将a左右儿子换位
update(ans, a, b);
return ans.cnt;
}
char s[5];
ll a, b, c;
template
void read(T &x) {
x = 0; int f = 1;
char c = getchar();
for (;!isdigit(c); c = getchar()) if (c == '-') f = -1;
for (;isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
x *= f;
}
int main() {
read(n), read(m);
for (int i = 1;i <= n; i++) read(w[i]);
for (int i = 1;i < n; i++) {
read(a), read(b);
add(a, b); add(b, a);
}
dfs1(1, 0);
dfs2(1, 1);
build(1, n, 1);
while (m--) {
scanf ("%s", s + 1);
if (s[1] == 'C') {
read(a), read(b), read(c);
change_e(a, b, c);
}
else {
read(a), read(b);
printf ("%d\n", sum(a, b));
}
}
return 0;
}