P2486 [SDOI2011]染色

P2486

很经典的题~

思路: 线段树染色+"熟练"剖分(某些出题人总是喜欢把序列上的题加个树链剖分搞到树上去)

先想一想序列上怎么做吧

线段树是个好东西

每个节点维护三个信息: ls: 左端点的颜色 rs: 右端点的颜色 cnt: [l, r] 中共有几个颜色段

合并?

fa.cnt = son1.cnt + son2.cnt - [son1.rs == son2.ls]

fa.ls = son1.ls , fa.rs = son2.rs

爹的左端点颜色就是左儿子的左端点颜色, 右端点亦然

如果左儿子与右儿子相接的颜色相同, 那么等于左儿子块数加右儿子块数-1(中间两个块会合成一个)

否则直接加就行啦

修改时要打标记 记录有没有被覆盖

回到树上问题时要特别注意的是询问

因为询问时有swap的操作, 将k记录x,y的顺序, 即相当于(x, y) 还是(y, x)

如果是(y, x), 最后还要反回来才能进行合并

在跳重链时, 总是将链接在它的左边, 最后将a左右儿子反一下再与b合并即可

#include
#include
#include
#define ll long long
using namespace std;
const int N = 105000*4;
int fa[N], id[N], siz[N];
int num, dep[N], son[N];
int w[N], wt[N], Top[N];
int h[N], ne[N], to[N];
int tot;
inline void add(int x,int y) {
    ne[++tot] = h[x], h[x] = tot;
    to[tot] = y;
}
void dfs1(int x,int f) {
    fa[x] = f;
    siz[x] = 1, dep[x] = dep[f] + 1;
    for (int i = h[x]; i ;i = ne[i]) {
        int y = to[i];
        if (y == f) continue;
        dfs1(y, x);
        siz[x] += siz[y];
        if (siz[y] > siz[son[x]]) son[x] = y;
    }
}
void dfs2(int x,int topf) {
    id[x] = ++num; wt[num] = w[x];
    Top[x] = topf;
    if (!son[x]) return;
    dfs2(son[x], topf);
    for (int i = h[x]; i; i = ne[i]) {
        int y = to[i];
        if (y == fa[x] || y == son[x]) continue;
        dfs2(y, y);
    }
}
int n, m;
int L[N], R[N], cnt[N], ls[N], rs[N];
int tag[N];
#define p1 p << 1
#define p2 p << 1 | 1

struct node{
    int cnt, ls, rs;
};
void update(node &fa,node i,node j) {
    fa.cnt = i.cnt + j.cnt - (i.rs == j.ls);
    fa.ls = i.ls, fa.rs = j.rs;
}

void build(int l,int r,int p) {
    L[p] = l, R[p] = r;
    if (l == r) {
        cnt[p] = 1, ls[p] = rs[p] = wt[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, p1);
    build(mid+1, r, p2);
    cnt[p] = cnt[p1] + cnt[p2] - (rs[p1] == ls[p2]);
    ls[p] = ls[p1], rs[p] = rs[p2];
}

void spread(int p) {
    if (tag[p]) {
        cnt[p1] = cnt[p2] = 1;
        tag[p1] = tag[p2] = tag[p];
        ls[p1] = ls[p2] = rs[p1] = rs[p2] = tag[p];
        tag[p] = 0;
    }
}

void change(int l,int r,int p,int c) {
    if (L[p] >= l && R[p] <= r) {
        cnt[p] = 1, tag[p] = c;
        ls[p] = rs[p] = c;
        return;
    }
    spread(p);
    if (R[p1] >= l) change(l, r, p1, c);
    if (L[p2] <= r) change(l, r, p2, c);
    cnt[p] = cnt[p1] + cnt[p2] - (rs[p1] == ls[p2]);
    ls[p] = ls[p1], rs[p] = rs[p2];
}

node ask(int l,int r,int p) {
    if (L[p] >= l && R[p] <= r) return (node){cnt[p], ls[p], rs[p]};
    spread(p);
    node i;
    if (R[p1] < l) return ask(l, r, p2);
    if (L[p2] > r) return ask(l, r, p1);
    update(i, ask(l, r, p1), ask(l, r, p2));
    return i;
}
    
void change_e(int x,int y,int c) {
    while (Top[x] != Top[y]) {
        if (dep[Top[x]] < dep[Top[y]]) swap(x, y);
        change(id[Top[x]], id[x], 1, c);
        x = fa[Top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    change(id[y], id[x], 1, c);
}

int sum(int x,int y) {
    node ans, a, b;
    ans = a = b = (node){0,0,0};
    int k = 1;
    while (Top[x] != Top[y]) {
        if (dep[Top[x]] < dep[Top[y]]) swap(x, y), swap(a, b), k ^= 1;
        if (a.cnt == 0) a = ask(id[Top[x]], id[x], 1);
        else update(a, ask(id[Top[x]], id[x], 1), a);
        x = fa[Top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y), swap(a, b);
    if (a.cnt == 0) a = ask(id[y], id[x], 1);
    else update(a, ask(id[y], id[x], 1), a);
    if (b.cnt == 0) return a.cnt;
    if (a.cnt == 0) return b.cnt;
    if (!k) swap(a, b); // 将a, b恢复正常顺序
    swap(a.ls, a.rs); //将a左右儿子换位
    update(ans, a, b);
    return ans.cnt;
}
    
        
char s[5];
ll a, b, c;

template  
void read(T &x) {
    x = 0; int f = 1;
    char c = getchar();
    for (;!isdigit(c); c = getchar()) if (c == '-') f = -1;
    for (;isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}
int main() {
    read(n), read(m);
    for (int i = 1;i <= n; i++) read(w[i]);
    for (int i = 1;i < n; i++) {
        read(a), read(b);
        add(a, b); add(b, a);
    }
    dfs1(1, 0); 
    dfs2(1, 1);
    build(1, n, 1);
    while (m--) {
        scanf ("%s", s + 1);
        if (s[1] == 'C') {
            read(a), read(b), read(c);
            change_e(a, b, c);
        }
        else {
            read(a), read(b);
            printf ("%d\n", sum(a, b));
        }
    }
    return 0;
}

你可能感兴趣的:(P2486 [SDOI2011]染色)