bzoj2243: [SDOI2011]染色

题目

bzoj2243

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

题解

树链剖分,然后用线段数维护每条链上的信息。
线段树中记录lc,rc,data分别表示线段左端点颜色,右端点颜色,与颜色段数。向上更新父亲时用:data[k]=data[k*2]+data[k*2+1]-(rc[k*2]==lc[k*2+1)。
在询问两点间路径时优先跳top的深度深的直到两点top相同。同时还要记录x或y已更新线段的左端点颜色tx与ty。询问另一段区间后判断rc是否等于tx(或ty)若相等则答案减一。

然而自己一开始写的时候TLE,最后发现是我SB了,dfs时忘了更新子树大小。结果重链全是乱的。改完之后4000多毫秒就过了。T_T

下面贴代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

#define maxn 100100
struct node{
    int l,r,lc,rc,tag,d;
}t[maxn*3];
struct edge{
    int x,next;
}e[maxn*2];
char s;
int c[maxn],first[maxn],n,m,x,y,tot,dfn,cc[maxn],z;
int son[maxn],top[maxn],size[maxn],dep[maxn],f[maxn],pos[maxn];

inline void add(int x,int y){
    e[++tot].next=first[x];
    e[tot].x=y;
    first[x]=tot;
}

void dfs1(int x){
    size[x]=1;
    for(int i=first[x];i;i=e[i].next)
    if(e[i].x!=f[x]){
        f[e[i].x]=x;
        dep[e[i].x]=dep[x]+1;
        dfs1(e[i].x);
        size[x]+=size[e[i].x];
        if(size[e[i].x]>size[son[x]])son[x]=e[i].x;
    }
}

void dfs2(int x,int y){
    pos[x]=++dfn; top[x]=y;
    c[dfn]=cc[x];
    if(son[x])dfs2(son[x],y);
    for(int i=first[x];i;i=e[i].next)
    if(e[i].x!=f[x]&&e[i].x!=son[x])dfs2(e[i].x,e[i].x);
}

void built(int l,int r,int k){
    t[k].l=l; t[k].r=r; t[k].tag=-1;
    if(l==r){
        t[k].lc=t[k].rc=c[l];
        t[k].d=1;
        return;
    }
    int mid=(l+r)>>1;
    built(l,mid,k*2);
    built(mid+1,r,k*2+1);
    t[k].d=t[k*2].d+t[k*2+1].d-(t[k*2].rc==t[k*2+1].lc);
    t[k].lc=t[k*2].lc; t[k].rc=t[k*2+1].rc;
}

inline void updata(int k){
    if(t[k].l==t[k].r)t[k].tag=-1;
    if(t[k].tag==-1)return;
    t[k*2].lc=t[k*2].rc=t[k*2].tag=t[k].tag;
    t[k*2+1].lc=t[k*2+1].rc=t[k*2+1].tag=t[k].tag;
    t[k*2+1].d=t[k*2].d=1;
    t[k].tag=-1;
}

void modify(int l,int r,int k,int x){
    if(l==t[k].l&&r==t[k].r){
        t[k].lc=t[k].rc=t[k].tag=x;
        t[k].d=1;
        return;
    }
    updata(k);
    int mid=(t[k].l+t[k].r)>>1;
    if(r<=mid)modify(l,r,k*2,x);
    else if(l>mid)modify(l,r,k*2+1,x);
    else modify(l,mid,k*2,x),modify(mid+1,r,k*2+1,x);
    t[k].d=t[k*2].d+t[k*2+1].d-(t[k*2].rc==t[k*2+1].lc);
    t[k].lc=t[k*2].lc; t[k].rc=t[k*2+1].rc;
}

int ask(int l,int r,int k,int &lc,int &rc){
    if(l==t[k].l&&r==t[k].r){
        lc=t[k].lc; rc=t[k].rc;
        return t[k].d;
    }
    updata(k);
    int mid=(t[k].l+t[k].r)>>1;
    if(r<=mid)return ask(l,r,k*2,lc,rc);
    else if(l>mid)return ask(l,r,k*2+1,lc,rc);
    else{
        int t1=0,t2=0,tmp=0;
        tmp=ask(l,mid,k*2,lc,t1)+ask(mid+1,r,k*2+1,t2,rc);
        if(t1==t2)tmp--;
        return tmp;
    }
}

void change(int x,int y,int k){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        modify(pos[top[x]],pos[x],1,k);
        x=f[top[x]];
    }
    if(dep[x]<dep[y])swap(x,y);
    modify(pos[y],pos[x],1,k);
}

int solve(int x,int y){
    int lc,rc,tx=-1,ty=-1,ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y),swap(tx,ty);
        ans+=ask(pos[top[x]],pos[x],1,lc,rc);
        ans-=(rc==tx);
        tx=lc;
        x=f[top[x]];
    }
    if(dep[x]<dep[y])swap(x,y),swap(tx,ty);
    ans+=ask(pos[y],pos[x],1,lc,rc);
    ans-=(lc==ty)+(rc==tx);
    return ans;
}

inline void read(int &x){
    char c=getchar();
    for(;c<'0'||c>'9';c=getchar());
    for(x=0;'0'<=c&&c<='9';c=getchar())x=x*10+c-'0';
}

int main(){
    read(n); read(m);
    for(int i=1;i<=n;i++)scanf("%d",&cc[i]);
    for(int i=1;i<n;i++){
        read(x); read(y);
        add(x,y);
        add(y,x);
    }
    dfs1(1); dfs2(1,1);
    built(1,dfn,1);
    for(int i=1;i<=m;i++){
        for(s=getchar();s!='C'&&s!='Q';s=getchar());
        read(x); read(y);
        if(s=='C'){
            read(z);
            change(x,y,z);
        }else
        printf("%d\n",solve(x,y));
    }
    return 0;
}

你可能感兴趣的:(bzoj2243: [SDOI2011]染色)