给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
干,真的很难。
树论的题代码一个比一个长,一个比一个难调。
wa了几发才看见原来0也算一种颜色,赶紧把lazy的初始值改成-1。
#include
using namespace std;
const int maxn = 1e5 + 5;
vector<int> G[maxn];
int n, m, in[maxn];
//子节点个数,深度
int sz[maxn], dep[maxn];
//重儿子,父节点
int ch[maxn], fa[maxn];
//重链开头,dfs序
int top[maxn], tid[maxn], tid2[maxn];
int tot;
void dfs1(int u, int f, int d) {
sz[u] = 1;
fa[u] = f;
dep[u] = d;
for(int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if(v == f) {
continue;
}
dfs1(v, u, d + 1);
sz[u] += sz[v];
if(ch[u] == -1 || sz[v] > sz[ch[u]]) {
ch[u] = v;
}
}
}
void dfs2(int u, int tp) {
top[u] = tp;
tid[u] = ++tot;
tid2[tot] = u;
if(ch[u] == -1) {
return;
}
dfs2(ch[u], tp);
for(int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if(v != ch[u] && v != fa[u]) {
dfs2(v, v);
}
}
}
struct SegmentTree {
int sum[maxn << 2], lazy[maxn << 2];
int left[maxn << 2], right[maxn << 2];
void pushUp(int i) {
left[i] = left[i << 1];
right[i] = right[i << 1 | 1];
sum[i] = sum[i << 1] + sum[i << 1 | 1] - (left[i << 1 | 1] == right[i << 1]);
}
void pushDown(int i) {
if(lazy[i] == -1) {
return;
}
left[i << 1] = right[i << 1] = lazy[i];
left[i << 1 | 1] = right[i << 1 | 1] = lazy[i];
lazy[i << 1] = lazy[i << 1 | 1] = lazy[i];
sum[i << 1] = sum[i << 1 | 1] = 1;
lazy[i] = -1;
}
void build(int i, int l, int r) {
lazy[i] = -1;
if(l == r) {
left[i] = right[i] = in[tid2[l]];
sum[i] = 1;
return;
}
int mid = (l + r) >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
pushUp(i);
}
void update(int i, int l, int r, int L, int R, int cor) {
if(l >= L && r <= R) {
sum[i] = 1;
left[i] = right[i] = cor;
lazy[i] = cor;
return;
}
pushDown(i);
int mid = (l + r) >> 1;
if(L <= mid) {
update(i << 1, l, mid, L, R, cor);
}
if(R > mid) {
update(i << 1 | 1, mid + 1, r, L, R, cor);
}
pushUp(i);
}
int query(int i, int l, int r, int L, int R) {
if(l >= L && r <= R) {
return sum[i];
}
pushDown(i);
int mid = (l + r) >> 1;
if(R <= mid) {
return query(i << 1, l, mid, L, R);
}
if(L > mid) {
return query(i << 1 | 1, mid + 1, r, L, R);
}
return query(i << 1, l, mid, L, R) + query(i << 1 | 1, mid + 1, r, L, R) - (left[i << 1 | 1] == right[i << 1]);
}
int query(int i, int l, int r, int pos) {
if(l == r){
return left[i];
}
pushDown(i);
int mid = (l + r) >> 1;
if(pos <= mid) {
return query(i << 1, l, mid, pos);
}
return query(i << 1 | 1, mid + 1, r, pos);
}
int solve(int u, int v, int val) {
int f1 = top[u], f2 = top[v];
// 不在同一条链上
int ans = 0;
while(f1 != f2) {
if(dep[f1] < dep[f2]) {
swap(f1, f2);
swap(u, v);
}
if(val >= 0) {
update(1, 1, n, tid[f1], tid[u], val);
} else {
ans += query(1, 1, n, tid[f1], tid[u]);
ans -= query(1, 1, n, tid[fa[f1]]) == query(1, 1, n, tid[f1]);
}
u = fa[f1];
f1 = top[u];
}
if(dep[u] < dep[v]) {
swap(u, v);
}
if(val >= 0) {
update(1, 1, n, tid[v], tid[u], val);
return 0;
}
return ans + query(1, 1, n, tid[v], tid[u]);
}
} st;
int main() {
char op[3];
scanf("%d%d", &n, &m);
memset(ch, -1, sizeof(ch));
tot = 0;
for(int i = 1;i <= n; i++){
G[i].clear();
}
for(int i = 1; i <= n; i++) {
scanf("%d", &in[i]);
}
int u, v, val;
for(int i = 1; i <= n - 1; i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1, 0, 0);
dfs2(1, 1);
st.build(1, 1, n);
while(m--) {
scanf("%s%d%d", op, &u, &v);
if(op[0] == 'C') {
scanf("%d", &val);
st.solve(u, v, val);
} else {
printf("%d\n", st.solve(u, v, -1));
}
}
return 0;
}