树链剖分 HYSBZ 2243 染色

题目很好理解：给定一棵树，每个节点有一种颜色；操作 C a b c 把 a 到 b 路径上的所有节点染成颜色 c，操作 Q a b 询问 a 到 b 路径上共有多少段颜色（相邻的同色节点算作同一段）。

线段树部分不好维护（每个节点要记录区间左右两端的颜色，合并两个子区间时若相接处同色，段数要减一），写了好久。

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
using namespace std; 
// Capacity of every per-vertex array; presumably sized for n <= 1e5 vertices plus slack.
const int M = 100000 + 10;
// head[x]: index in edge[] of the most recently added edge leaving x (-1 once reset = empty list).
int head[M];
// cnt: next free slot in edge[].
int cnt;
// One directed edge of the adjacency list; main() stores both directions of each tree edge.
struct Edge{
    int u;     // source vertex
    int v;     // target vertex
    int next;  // index of the previous edge leaving u (-1 terminates the list)
    // Records the directed edge from -> to and prepends it to from's list.
    // Must be invoked on edge[cnt]: the slot index stored into head[] is cnt, which is then advanced.
    void set(int from, int to){
        u = from;
        v = to;
        next = head[from];
        head[from] = cnt;
        ++cnt;
    }
}edge[M << 1];
// Segment-tree node covering DFS positions [l, r].
struct Node{
    int l;   // left end of the covered position range
    int r;   // right end of the covered position range
    int n;   // number of maximal same-colour runs inside the range
    int ll;  // colour at position l
    int rr;  // colour at position r
    int f;   // lazy flag: 1 if the whole range was assigned one colour and the children are stale
}node[M << 4];
// Heavy-light decomposition data, indexed by vertex.
int pre[M];  // parent vertex (0 for the root as called from main)
int son[M];  // heavy child (0 = leaf)
int siz[M];  // subtree size
int top[M];  // topmost vertex of the heavy chain containing this vertex
int dep[M];  // depth
int pos[M];  // 1-based position in the heavy-first DFS order = segment-tree index
int tot;     // last position handed out by dfs_2
// Input: initial colour of each vertex, number of vertices, number of operations.
int s[M];
int n;
int m;

void build(int l, int r, int p){
    node[p].l = l, node[p].r = r; 
    node[p].f = 0, node[p].n = 0; 
    if(l == r)return ; 
    int mid = (l+r) >> 1; 
    build(l, mid, p << 1); 
    build(mid+1, r, p << 1|1); 
}
// Sets the leaf at position k (searched from node p) to colour c, making it a single run.
// Ancestors are NOT refreshed here; main() rebuilds them with update2 afterwards.
void update1(int k, int c, int p){
    while(node[p].l != node[p].r){
        const int mid = (node[p].l + node[p].r) >> 1;
        p = (p << 1) | (k > mid ? 1 : 0);
    }
    node[p].ll = c;
    node[p].rr = c;
    node[p].n = 1;
}
// Recomputes every internal node of the subtree for [l, r] rooted at p, bottom-up,
// from its children's end colours and run counts. Leaves are left untouched.
void update2(int l, int r, int p){
    if(l != r){
        const int mid = (l + r) >> 1;
        const int lc = p << 1;
        const int rc = lc | 1;
        update2(l, mid, lc);
        update2(mid + 1, r, rc);
        // The runs on both sides of the midpoint merge when the touching colours match.
        const bool joined = node[lc].rr == node[rc].ll;
        node[p].ll = node[lc].ll;
        node[p].rr = node[rc].rr;
        node[p].n = node[lc].n + node[rc].n - (joined ? 1 : 0);
    }
}
void dfs_1(int u, int f, int d){
    pre[u] = f, dep[u] = d; 
    siz[u] = 1, son[u] = 0; 
    for(int i = head[u]; ~i; i = edge[i].next){
        int v = edge[i].v; 
        if(v != f){
            dfs_1(v, u, d+1); 
            if(siz[son[u]]<siz[v])son[u] = v; 
            siz[u] += siz[v]; 
        }
    }
}
void dfs_2(int u, int tp){
    top[u] = tp, pos[u] = ++tot; 
    if(son[u])dfs_2(son[u], tp); 
    for(int i = head[u]; ~i; i = edge[i].next){
        int v = edge[i].v; 
        if(v != pre[u] && v != son[u])dfs_2(v, v); 
    } 
}
void pushdown(int p){
    node[p].f = 0; 
    node[p << 1].f = 1;
    node[p << 1].ll = node[p].ll; 
    node[p << 1].rr = node[p].rr; 
    node[p << 1].n = 1; 

    node[p << 1|1].f = 1;
    node[p << 1|1].ll = node[p].ll; 
    node[p << 1|1].rr = node[p].rr; 
    node[p << 1|1].n = 1; 
}
// Returns the number of colour runs in positions [l, r], which must lie inside node p's
// range. Pending assignments are pushed down along the way.
int query(int l, int r, int p){
    if(node[p].l == l && node[p].r == r){
        return node[p].n;
    }
    if(node[p].f) pushdown(p);
    const int mid = (node[p].l + node[p].r) >> 1;
    const int lc = p << 1;
    const int rc = lc | 1;
    if(r <= mid) return query(l, r, lc);
    if(l > mid) return query(l, r, rc);
    // The range straddles the midpoint: count both halves, then merge the touching runs.
    int runs = query(l, mid, lc);
    runs += query(mid + 1, r, rc);
    if(node[lc].rr == node[rc].ll) --runs;
    return runs;
}
int getcol(int k, int p){
    if(node[p].f){
        return node[p].ll;
    }
    if(node[p].l == node[p].r){
        return node[p].ll; 
    }
    int mid = (node[p].l+node[p].r) >> 1; 
    return getcol(k, p << 1|(k>mid)); 
}
int Qsum(int u, int v){
    int f1 = top[u], f2 = top[v];
    int ru = -1, rv = -1, sum = 0, tmp;  
    while(f1 != f2){
        if(dep[f1]>dep[f2]){
            sum += query(pos[f1], pos[u], 1); 
            sum -= (getcol(pos[u], 1) == ru); 
            ru = getcol(pos[f1], 1); 
            u = pre[f1], f1 = top[u]; 
        } 
        else{
            sum += query(pos[f2], pos[v], 1); 
            sum -= (getcol(pos[v], 1) == rv); 
            rv = getcol(pos[f2], 1); 
            v = pre[f2], f2 = top[v]; 
        }
    }
    if(dep[u]<dep[v]){
        tmp = query(pos[u], pos[v], 1);
    }
    else{
        tmp = query(pos[v], pos[u], 1); 
    }
    sum += tmp;
    sum -= (ru == getcol(pos[u], 1)); 
    sum -= (rv == getcol(pos[v], 1)); 
    return sum; 
}
void update(int l, int r, int p, int c){
    if(node[p].l == l && node[p].r == r){
        node[p].ll = c, node[p].rr = c; 
        node[p].n = 1, node[p].f = 1; 
        return ; 
    }
    if(node[p].f)pushdown(p); 
    int mid = (node[p].l+node[p].r) >> 1; 
    if(l <= mid){
        update(l, min(r, mid), p << 1, c); 
    }
    if(r>mid){
        update(max(l, mid+1), r, p << 1|1, c); 
    }
    node[p].ll = node[p << 1].ll; 
    node[p].rr = node[p << 1|1].rr; 
    node[p].n = node[p << 1].n+node[p << 1|1].n-(node[p << 1].rr == node[p << 1|1].ll); 
}
void change(int u, int v, int c){
    int f1 = top[u], f2 = top[v];
    while(f1 != f2){
        if(dep[f1]<dep[f2]){
            swap(f1, f2), swap(u, v); 
        }
        update(pos[f1], pos[u], 1, c); 
        u = pre[f1], f1 = top[u]; 
    }
    if(dep[u]<dep[v])swap(u, v);
    update(pos[v], pos[u], 1, c); 
}
// Reads test cases until EOF. Each case gives: n and m, the n initial vertex colours,
// n-1 tree edges, then m operations. "Q u v" prints the number of colour runs on path
// u..v; any other operation letter is read as "<op> u v c" and recolours path u..v with c.
int main(){
    while(cin >> n >> m){
        cnt = 0, tot = 0;
        memset(head, -1, sizeof(head));   // -1 terminates every adjacency list
        for(int i = 1; i <= n; i++){
            scanf("%d", &s[i]);
        }
        build(1, n, 1);
        for(int i = 1, u, v; i < n; i++){
            scanf("%d%d", &u, &v);
            edge[cnt].set(u, v);
            edge[cnt].set(v, u);
        }
        // The decomposition is rooted at the middle vertex label.
        const int root = (n + 1) / 2;
        dfs_1(root, 0, 1);
        dfs_2(root, root);
        // Load the initial colours into the leaves, then rebuild internal nodes bottom-up.
        for(int i = 1; i <= n; i++){
            update1(pos[i], s[i], 1);
        }
        update2(1, n, 1);
        while(m--){
            char op[5];
            // Width-limited read: an over-long token can no longer overflow op, and a failed
            // read (EOF) no longer leaves op uninitialised.
            if(scanf("%4s", op) != 1) return 0;
            if(op[0] == 'Q'){
                int u, v;
                scanf("%d%d", &u, &v);
                printf("%d\n", Qsum(u, v));
            }
            else{
                int u, v, c;
                scanf("%d%d%d", &u, &v, &c);
                change(u, v, c);
            }
        }
    }
    return 0;
}

你可能感兴趣的:(数据结构)