2243: [SDOI2011] 染色 —— 树链剖分

对于线段树的操作:每个结点维护区间左端颜色、右端颜色以及颜色段数(不是颜色种类数)。合并左右儿子时,如果左儿子的右端颜色 == 右儿子的左端颜色,说明两段在交界处连成一段,颜色段数减一。剩下的就是细节了。

#include<iostream>

#include<cstdio>

#include<cstring>

#include<map>

#include<vector>

#include<stdlib.h>

using namespace std;

// Argument shorthand for recursing into the left child of a segment-tree
// node: expands to the half-range [l, mid] and the child index rt*2.
// Requires locals named l, mid and rt to be in scope where it is used.
#define lson l,mid,rt<<1

// Same for the right child: half-range [mid+1, r] and child index rt*2+1.
// Requires locals named mid, r and rt to be in scope where it is used.
#define rson mid+1,r,rt<<1|1

// Capacity of the per-vertex arrays. The edge pool is declared with twice
// this size and the segment-tree arrays with four times this size.
const int maxn = 222222;

// deep[v]: depth of vertex v. main() sets deep[1] = 1 and init() gives every
// child its parent's depth plus one.
int deep[maxn];

// One directed edge, stored as an entry of an array-backed linked list.
struct Node
{
    int next;  // index in e[] of the next edge leaving the same vertex, or -1 at the end
    int to;    // vertex this edge leads to
};

// Edge pool. Every undirected tree edge is inserted in both directions,
// which is why the pool holds 2 * maxn entries.
Node e[maxn * 2];



// --- adjacency list state ---
int len;                // number of entries of e[] currently in use
int z;                  // number of positions handed out by dfs(); also the segment-tree range [1, z]
int head[maxn];         // head[v]: index in e[] of v's first edge, or -1 if v has none

// --- heavy-light decomposition state ---
int top[maxn];          // top[v]: topmost vertex of the chain that contains v
int son[maxn];          // son[v]: child of v with the largest subtree, 0 if v has no child
int father[maxn];       // father[v]: parent of v as assigned by init()
int size[maxn];         // size[v]: number of vertices in the subtree rooted at v
int pos[maxn];          // pos[v]: position of v in the dfs() visiting order (1-based)
int val[maxn];          // val[v]: initial colour of vertex v as read from the input
int vis[maxn];          // vis[p]: vertex that sits at position p (inverse of pos)

// --- segment tree state, indexed by node number ---
int color[maxn << 2];   // pending "paint whole range" colour, -1 when nothing is pending
int sum[maxn << 2];     // number of maximal same-colour runs inside the node's range
int lsum[maxn << 2];    // colour of the leftmost position of the node's range
int rsum[maxn << 2];    // colour of the rightmost position of the node's range





// Inserts the directed edge from -> to at the front of from's adjacency list.
// Takes the next free slot of e[] and advances len.
void add(int from, int to)
{
    const int slot = len++;
    e[slot].to = to;
    e[slot].next = head[from];  // old first edge becomes the successor
    head[from] = slot;
}





void init(int x)

{

    size[x] = 1; son[x] = 0;

    for (int i = head[x]; i != -1; i = e[i].next){

        int cc = e[i].to;

        if (cc == father[x]) continue;

        father[cc] = x; deep[cc] = deep[x] + 1;

        init(cc);

        size[x] += size[cc];

        if (size[son[x]] < size[cc]) son[x] = cc;

    }

}



// Second decomposition pass: assigns consecutive positions so that every
// chain (a vertex followed by its son[] descendants) occupies a contiguous
// range of positions.
//
// x   vertex to number
// tp  top vertex of the chain x belongs to
void dfs(int x, int tp)
{
    ++z;
    pos[x] = z;
    vis[z] = x;
    top[x] = tp;

    // Visit the heavy child first so it gets the very next position and
    // stays on the same chain.
    const int heavy = son[x];
    if (heavy != 0) dfs(heavy, tp);

    // Every other child starts a new chain with itself as the top.
    for (int i = head[x]; i != -1; i = e[i].next) {
        const int nxt = e[i].to;
        if (nxt != heavy && nxt != father[x]) dfs(nxt, nxt);
    }
}



// Recomputes node rt from its two children: end colours come from the outer
// ends of the children, and the run count is the sum of the children's run
// counts, minus one when the two runs that touch in the middle share a colour.
void up(int rt)
{
    const int lc = rt << 1;
    const int rc = lc | 1;

    lsum[rt] = lsum[lc];
    rsum[rt] = rsum[rc];

    const int merged = (rsum[lc] == lsum[rc]) ? 1 : 0;
    sum[rt] = sum[lc] + sum[rc] - merged;
}



// Builds the segment tree for positions [l, r] at node rt. The leaf for
// position p takes the colour val[vis[p]], i.e. the initial colour of the
// vertex placed at p.
void build(int l, int r, int rt)
{
    if (l == r) {
        const int c = val[vis[l]];
        color[rt] = c;
        lsum[rt] = c;
        rsum[rt] = c;
        sum[rt] = 1;  // a single position is exactly one run
        return;
    }

    color[rt] = -1;  // inner nodes start with no pending paint
    const int mid = (l + r) >> 1;
    build(l, mid, rt << 1);
    build(mid + 1, r, rt << 1 | 1);
    up(rt);
}



// Pushes a pending "paint whole range" colour from node rt to both children
// and clears it on rt. Does nothing when no paint is pending (tag is -1).
void down(int rt)
{
    const int tag = color[rt];
    if (tag == -1) return;

    for (int k = 0; k < 2; ++k) {
        const int child = (rt << 1) | k;
        color[child] = tag;
        lsum[child] = tag;
        rsum[child] = tag;
        sum[child] = 1;  // a uniformly painted range is a single run
    }
    color[rt] = -1;
}





// Paints every position in [L, R] with colour ans.
//
// L, R     range to paint
// ans      colour to assign
// l, r, rt range covered by the current node and its index
void update(int L, int R, int ans, int l, int r, int rt)
{
    if (l < L || R < r) {
        // Node sticks out of the painted range: resolve pending paint,
        // descend into the overlapping halves, then rebuild this node.
        down(rt);
        const int mid = (l + r) >> 1;
        if (L <= mid) update(L, R, ans, l, mid, rt << 1);
        if (mid < R) update(L, R, ans, mid + 1, r, rt << 1 | 1);
        up(rt);
        return;
    }

    // Node lies completely inside [L, R]: record the paint here and stop.
    color[rt] = ans;
    lsum[rt] = ans;
    rsum[rt] = ans;
    sum[rt] = 1;
}



// Returns the number of maximal same-colour runs among positions [L, R].
//
// L, R     queried range
// l, r, rt range covered by the current node and its index
int ask(int L, int R, int l, int r, int rt)
{
    if (L <= l && r <= R) return sum[rt];

    down(rt);
    const int mid = (l + r) >> 1;
    const bool goLeft = L <= mid;
    const bool goRight = R > mid;

    int total = 0;
    if (goLeft) total += ask(L, R, l, mid, rt << 1);
    if (goRight) total += ask(L, R, mid + 1, r, rt << 1 | 1);

    // When the query spans both halves, positions mid and mid+1 are both
    // inside it; if they share a colour the two partial answers counted the
    // same run twice.
    if (goLeft && goRight && rsum[rt << 1] == lsum[rt << 1 | 1]) --total;
    return total;
}



// Paints every vertex on the tree path between x and y with colour ans.
void gao(int x, int y, int ans)
{
    // While the endpoints sit on different chains, paint the chain prefix
    // of the endpoint whose chain top is deeper and jump above that chain.
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        const int chainTop = top[x];
        update(pos[chainTop], pos[x], ans, 1, z, 1);
        x = father[chainTop];
    }

    // Same chain now: the shallower vertex has the smaller position.
    if (deep[x] > deep[y]) swap(x, y);
    update(pos[x], pos[y], ans, 1, z, 1);
}



// Returns the colour stored at position key by walking from node rt
// (covering [l, r]) down to the leaf, pushing pending paint on the way.
int ask1(int key, int l, int r, int rt)
{
    while (l != r) {
        down(rt);
        const int mid = (l + r) >> 1;
        if (key <= mid) {
            r = mid;
            rt = rt << 1;
        } else {
            l = mid + 1;
            rt = rt << 1 | 1;
        }
    }
    return color[rt];  // at a leaf the tag holds the position's colour
}



// Returns the number of maximal same-colour runs along the tree path
// between x and y.
int gao1(int x, int y)
{
    int segments = 0;

    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) swap(x, y);
        const int chainTop = top[x];

        segments += ask(pos[chainTop], pos[x], 1, z, 1);

        // Author's note: this was once mistyped as father[x], which took a
        // long time to spot. The jump must start from the chain top.
        x = father[chainTop];

        // The chain top and its parent are adjacent on the path but live in
        // different segment-tree ranges; if they share a colour, the run
        // continues across the jump and must not be counted twice.
        const int above = ask1(pos[x], 1, z, 1);
        const int below = ask1(pos[chainTop], 1, z, 1);
        if (above == below) --segments;
    }

    if (deep[x] > deep[y]) swap(x, y);
    // Author's note: pos[x] was once mistyped as pos[son[x]] here.
    segments += ask(pos[x], pos[y], 1, z, 1);
    return segments;
}



// Reads test cases until end of input. Each case is: n and m, the n initial
// vertex colours, n-1 tree edges, then m operations. An operation whose
// first character is 'C' reads "a b c" and paints the path a..b with colour
// c; any other operation reads "a b" and prints the number of colour runs on
// the path a..b.
//
// Every scanf result is checked and the operation token is read with a width
// limit, so truncated or malformed input ends the program instead of
// overflowing str or running on unset values.
int main()
{
    char str[100];
    int n, m;
    int a, b, c;

    while (cin >> n >> m) {
        // Reset the edge pool and the position counter for this test case.
        z = len = 0;
        memset(head, -1, sizeof(head));

        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &val[i]) != 1) return 0;
        }
        for (int i = 1; i <= n - 1; i++) {
            if (scanf("%d%d", &a, &b) != 2) return 0;
            add(a, b);
            add(b, a);
        }

        // Root the tree at vertex 1, decompose it and build the segment tree.
        deep[1] = 1;
        init(1);
        dfs(1, 1);
        build(1, z, 1);

        for (int i = 0; i < m; i++) {
            // %99s leaves room for the terminating NUL in str[100].
            if (scanf("%99s", str) != 1) return 0;

            if (str[0] == 'C') {
                if (scanf("%d%d%d", &a, &b, &c) != 3) return 0;
                gao(a, b, c);
            } else {
                if (scanf("%d%d", &a, &b) != 2) return 0;
                printf("%d\n", gao1(a, b));
            }
        }
    }
    return 0;
}

 

你可能感兴趣的:(sd)