SDOI2011_染色

SDOI_染色

背景:很早就想学习树链剖分,趁着最近有点自由安排的时间去学习一下,发现有个很重要的前置知识——线段树。(其实不一定是线段树,但是线段树应该是最常见的),和同学吐槽说树剖的剖和分都很死板,主要还是看线段树的维护功底。但是也要知道剖分完的结果,不然就算线段树玩得飞起,也维护不了。看了网上很多博客,都是说一个geth,一个mark完成树链剖分,然后映射到线段树上,进行维护,其实这只是一个大体思想,还是建议自己手动模拟一下去加深理解。

前置知识:
1、重儿子:hs[u]=v,表示vu的重儿子。意思是vu的儿子中子树规模(包括自己)最大的。
轻儿子:除了重儿子的其他儿子。
2、重链:由重儿子组成的链。
轻链:除了重链的其他链。
3、顶端结点:重链的开头。
说是重轻分解,其实实质是把重链揪出来(即从轻链处砍断连接关系)连在一起拼凑成区间(同一条重链上结点编号映射到数据结构上连续),用数据结构维护,也就是说把树变成由重链组成的,只剩下重链,不考虑轻链,对于映射,同一条重链中浅结点编号小。
个人觉得这句话很通俗易懂了。~

关于树链剖分:
首先:要是一棵树。。然后有几种剖分:1、随便剖分,爱怎么编号怎么编号。2、启发式剖分(也就是常见的重轻分解)。显然!2比较科学,随便的东西肯定不稳定,就算是不随便的也不一定稳定。。(基数排序最后倒数组时的for downto...就比for to.稳定,而正for显然不是随便的东西,我不会证明也一直没想明白,还望看官指点)。我们可以很简单的运用两次dfs完成对一棵树的剖分(第一次:geth,第二次:mark)。第一次主要是得到深度、父亲、规模、重儿子;第二次则是将同一条重链上的结点编号在一起,对应到线段树上。(rank[]、sa[] 这两个数组和在后缀数组中一样,因为不懂后缀数组,所以用这两个提醒自己还是个弱者),以及记录重链顶端结点。
其次:其实可以说,树链剖分的题,暴力求解就是树上倍增(跑LCA,然后沿途更新),那么如何优化?显然LCA肯定要跑,有没有办法跑得更快?答案是肯定的,树链的剖分就是让LCA跑得更快。显然对于V(u,v)要么在一条重链上,要么不在一条重链上。如果在一条重链上,深度浅的就是LCA,如果不在呢?不妨定义u为深度更深的结点,那么倍增的思想告诉我们应该把u跳到和v一样浅,然后一起跳。然而轻重分解直接把u跳到其所在重链顶端(期间维护和求解该链上的答案),判断u,v在不在一条重链上(tp[u]==tp[v]?),然后不断进行这个过程直到u,v在同一重链后运用数据结构维护求解。那么我们又知道了同一条重链新编号连续,那么进行区间维护就很方便了。
最后:看各位的线段树功底了,反正笔者的线段树是很差的。。
(PS:第一次看到给20s的题,有点刺激)


Code:

#pragma comment(linkerr, "/STACK: 1024000000,1024000000")
#include 
#define pb push_back
#define mp make_pair
#define eb emplace_back
#define em emplace
#define pii pair
#define de(x) cout << #x << " = " << x << endl
#define clr(a,b) memset(a,b,sizeof(a))
#define INF (0x3f3f3f3f)
#define LINF ((long long)(0x3f3f3f3f3f3f3f3f))
#define F first
#define S second
#define lson rt<<1,l,m
#define rson rt<<1|1,m+1,r
using namespace std;

const int N = 1e5 + 15;
int n, m;

int d[N], fa[N], sz[N], hs[N];
int nw, sa[N], rk[N], tp[N];
struct Edge
{
    int v, nxt;
};
Edge e[N<<1];
int h[N], ect;
void init()
{
    ect = nw = 0;
    clr(h,-1);
}
void _add( int u, int v )
{
    e[ect].v = v;
    e[ect].nxt = h[u];
    h[u] = ect ++;
}

void geth( int u, int f, int de )
{
    fa[u] = f;
    sz[u] = 1;
    hs[u] = 0;
    d[u] = de;
    for ( int i = h[u]; i+1; i = e[i].nxt )
    {
        int v = e[i].v;
        if ( v == f ) continue;
        geth( v, u, de+1 );
        sz[u] += sz[v];
        if ( sz[v] > sz[hs[u]] ) hs[u] = v;
    }
}
void mark( int u, int tu )
{
    tp[u] = tu;
    sa[++nw] = u; rk[u] = nw;
    if ( !hs[u] ) return ;
    mark( hs[u], tu );
    for ( int i = h[u]; i+1; i = e[i].nxt )
    {
        int v = e[i].v;
        if ( v != fa[u] && v != hs[u] ) mark(v,v);
    }
}

struct T
{
    int sm, lazy, lc, rc;
};
T t[N<<2];
int A[N];
int nwlc, nwrc;

void pushup( int rt )
{
    t[rt].sm = t[rt<<1].sm + t[rt<<1|1].sm;
    t[rt].lc = t[rt<<1].lc;
    t[rt].rc = t[rt<<1|1].rc;
    if ( t[rt<<1].rc == t[rt<<1|1].lc )
        t[rt].sm --;
}
void pushdown( int rt, int l, int r )
{
    if ( t[rt].lazy )
    {
        t[rt].lazy = 0;
        t[rt<<1].lazy = t[rt<<1|1].lazy = 1;
        t[rt<<1].sm = t[rt<<1|1].sm = 1;
        t[rt<<1].lc = t[rt<<1].rc = t[rt].lc;
        t[rt<<1|1].lc = t[rt<<1|1].rc = t[rt].rc;
    }
}

void build( int rt, int l, int r )
{
    t[rt].lazy = 0;
    if ( l == r )
    {
        t[rt].sm = 1;
        t[rt].lc = t[rt].rc = A[sa[l]];
        return ;
    }
    int m = (l+r) >> 1;
    build(lson); build(rson); pushup(rt);
}

void update( int L, int R, int c, int rt, int l, int r )
{
    if ( L <= l && r <= R )
    {
        t[rt].lc = t[rt].rc = c;
        t[rt].sm = t[rt].lazy = 1;
        return ;
    }
    int m = (l+r) >> 1;
    pushdown(rt,l,r);
    if ( L <= m ) update( L, R, c, lson );
    if ( R >  m ) update( L, R, c, rson );
    pushup(rt);
}

int query( int L, int R, int rt, int l, int r )
{
    if ( L == l ) nwlc = t[rt].lc;
    if ( R == r ) nwrc = t[rt].rc;
    if ( L <= l && r <= R )
        return t[rt].sm;
    int m = (l+r) >> 1, res = 0, lft = 0;
    
    pushdown(rt,l,r);
    if ( L <= m )
    {
        lft = 1;
        res += query( L, R, lson );
    }
    if ( R >  m )
    {
        res += query( L, R, rson );
        if ( lft && t[rt<<1].rc == t[rt<<1|1].lc ) res --;
    }
    pushup(rt);
    return res;
}

int getsum( int u, int v )
{
    int lstulc, lstvlc;
    lstulc = lstvlc = -1;
    int res = 0;
    int x = tp[u], y = tp[v];
    while ( x != y )
    {
        if ( d[x] < d[y] ) swap(x,y), swap(u,v), swap(lstulc,lstvlc);
        res += query( rk[x], rk[u], 1,1,n );
        if ( nwrc == lstulc ) res --;
        lstulc = nwlc;
        u = fa[x]; x = tp[u];
    }
    if ( d[u] > d[v] ) swap(u,v), swap( lstulc, lstvlc );
    res += query( rk[u], rk[v], 1,1,n );
    if ( nwlc == lstulc ) res --;
    if ( nwrc == lstvlc ) res --;
    return res;
}

void change( int u, int v, int c )
{
    int x = tp[u], y = tp[v];
    while ( x != y )
    {
        if ( d[x] < d[y] ) swap(x,y), swap(u,v);
        update( rk[x], rk[u], c, 1,1,n );
        u = fa[x]; x = tp[u];
    }
    if ( d[u] > d[v] ) swap( u, v );
    update( rk[u], rk[v], c, 1,1,n );
}

int main()
{
    init();
    scanf("%d%d", &n, &m);
    for ( int i = 1; i <= n; i ++ )
        scanf("%d", &A[i]);
    for ( int i = 1, u, v; i < n; i ++ )
    {
        scanf("%d%d", &u, &v);
        _add(u,v); _add(v,u);
    }
    geth(1,0,1);
    mark(1,1);
    build(1,1,n);
    
    while ( m -- )
    {
        char s[2];
        int u, v, c;
        scanf("%s %d%d", s, &u, &v);
        if ( s[0] == 'C' )
        {
            scanf("%d", &c);
            change( u, v, c );
        }
        else
            printf("%d\n", getsum(u,v));
    }
    return 0;
}

你可能感兴趣的:(SDOI2011_染色)