【HDU】5574 Colorful Tree【子树染色,询问子树颜色数——线段树+bit+lca+set】

题目链接:【HDU】5574 Colorful Tree

题目大意:给定一棵以 1 为根的树,每个点有一种颜色。操作 0 x y:把 x 的整棵子树染成颜色 y;另一种操作 x:询问 x 的子树中不同颜色的个数。
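题目分析:用 set 按 dfs 序维护所有"染色代表点"(自身颜色不是从祖先继承来的点,初始时每个点都是),并为每种颜色再开一个 set。树状数组在每个代表点处 +1,在同色且 dfs 序相邻的两个代表点的 LCA 处 -1,于是对子树区间 [in[x], ou[x]] 求和,就得到子树内代表点的颜色种数。若 x 本身不是代表点,再用线段树查出最近一次覆盖 x 的颜色;若子树内没有该颜色的代表点,答案再加 1。每次染色只新增一个代表点,被删除的代表点此前都插入过,因此修改的总代价均摊为 O((n+q) log n)。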

#include <cstdio>
#include <cstring>
#include <set>
#include <utility>
using namespace std ;

typedef long long LL ;
typedef pair < int , int > pii ;   // (dfs entry index, vertex id) in the sets below
#define clr( a , x ) memset ( a , x , sizeof a )

// Array bounds: per-vertex arrays use MAXN, the edge pool uses MAXE
// (solve() stores every undirected edge in both directions).
const int MAXN = 100005 ;
const int MAXE = 200005 ;

// One directed adjacency-list edge: target vertex `v`, and `n`, the index of
// the next edge leaving the same vertex (-1 terminates the list).
struct Edge {
    int v , n ;
    Edge () {}
    Edge ( int to , int nxt ) : v ( to ) , n ( nxt ) {}
} ;

Edge E[MAXE] ;          // edge pool
int H[MAXN] , cntE ;    // list head per vertex, number of edges used

// Heavy-light decomposition: depth, parent, chain top, heavy child, subtree size.
int dep[MAXN] , pre[MAXN] , top[MAXN] , son[MAXN] , siz[MAXN] ;
int tree_idx ;          // reset in solve(), otherwise unused in this file
// Euler tour: the subtree of u occupies exactly the dfs positions [in[u], ou[u]].
int in[MAXN] , ou[MAXN] , dfs_idx ;
int n , q ;

// s      : vertices that currently own a paint, keyed by (in[u], u)
// col[k] : the members of s whose colour is k
set < pii > s , col[MAXN] ;
set < pii > :: iterator it , it1 ;   // shared cursors; get() reads `it`
int c[MAXN] ;           // colour of vertex u (kept current for members of s)
int T[MAXN] ;           // Fenwick tree over dfs positions
int setv[MAXN << 2] ;   // segment tree of pending colour assignments

// Prepend the directed edge u -> v to u's adjacency list.
void addedge ( int u , int v ) {
    const int id = cntE ++ ;
    E[id] = Edge ( v , H[u] ) ;
    H[u] = id ;
}

void dfs ( int u ) {
    in[u] = ++ dfs_idx ;
    siz[u] = 1 ;
    son[u] = 0 ;
    for ( int i = H[u] ; ~i ; i = E[i].n ) {
        int v = E[i].v ;
        if ( v == pre[u] ) continue ;
        pre[v] = u ;
        dep[v] = dep[u] + 1 ;
        dfs ( v ) ;
        siz[u] += siz[v] ;
        if ( siz[son[u]] < siz[v] ) son[u] = v ;
    }
    ou[u] = dfs_idx ;
}

void rebuild ( int u , int top_element ) {
    top[u] = top_element ;
    if ( son[u] ) rebuild ( son[u] , top_element ) ;
    for ( int i = H[u] ; ~i ; i = E[i].n ) {
        int v = E[i].v ;
        if ( v != pre[u] && v != son[u] ) rebuild ( v , v ) ;
    }
}

// Lowest common ancestor of x and y using the heavy-light chains. Repeatedly
// lift whichever endpoint lies on the chain with the deeper top, until both
// share a chain; the shallower of the two is then the answer.
int lca ( int x , int y ) {
    for ( ; top[x] != top[y] ; ) {
        if ( dep[top[x]] > dep[top[y]] ) x = pre[top[x]] ;
        else y = pre[top[y]] ;
    }
    if ( dep[x] < dep[y] ) return x ;
    return y ;
}

void build ( int o , int l , int r ) {
    setv[o] = 0 ;
    if ( l == r ) {
        setv[o] = c[in[l]] ;
        return ;
    }
    int m = l + r >> 1 ;
    build ( o << 1 , l , m ) ;
    build ( o << 1 | 1 , m + 1 , r ) ;
}

// Push a pending colour assignment from node o down to both children.
// 0 means "nothing pending", so real colours must be non-zero.
void down ( int o ) {
    const int pending = setv[o] ;
    if ( !pending ) return ;
    setv[o << 1] = pending ;
    setv[o << 1 | 1] = pending ;
    setv[o] = 0 ;
}

// Assign colour v to every dfs position in [L, R]; (o, l, r) is the current
// segment tree node and the range it covers.
void update ( int L , int R , int v , int o , int l , int r ) {
    if ( L > l || r > R ) {
        // Partial overlap: settle pending work, then descend.
        down ( o ) ;
        const int mid = ( l + r ) >> 1 ;
        if ( L <= mid ) update ( L , R , v , o << 1 , l , mid ) ;
        if ( R > mid ) update ( L , R , v , o << 1 | 1 , mid + 1 , r ) ;
        return ;
    }
    setv[o] = v ;
}

// Colour currently assigned to dfs position x. Walks from node o to the leaf,
// pushing pending assignments down along the way.
int query ( int x , int o , int l , int r ) {
    while ( r > l ) {
        down ( o ) ;
        const int mid = ( l + r ) >> 1 ;
        o <<= 1 ;
        if ( x > mid ) {
            o |= 1 ;
            l = mid + 1 ;
        } else {
            r = mid ;
        }
    }
    return setv[o] ;
}

// Fenwick tree point update: add v at position x (positions 1..n).
void add ( int x , int v ) {
    while ( x <= n ) {
        T[x] += v ;
        x += x & -x ;
    }
}

// Fenwick tree prefix sum T over positions 1..x, added to the starting value `ans`.
int sum ( int x , int ans = 0 ) {
    for ( ; x > 0 ; x &= x - 1 ) ans += T[x] ;
    return ans ;
}

int get ( int x ) {
    int x1 = 0 , x2 = 0 ;
    if ( it != col[c[x]].begin () ) {
        it1 = it ;
        -- it1 ;
        x1 = lca ( it1->second , x ) ;
    }
    it1 = it ;
    ++ it1 ;
    if ( it1 != col[c[x]].end () ) x2 = lca ( it1->second , x ) ;
    if ( dep[x1] > dep[x2] ) return x1 ;
    return x2 ;
}

// Runs one test case: reads the tree (rooted at vertex 1), the initial
// colours and q operations, and prints one answer per query operation.
//
// Invariants maintained between operations:
//   * s holds every vertex that "owns" a paint, i.e. whose colour is not
//     simply inherited from an ancestor; initially that is every vertex.
//     col[k] holds the owners whose colour is k. Both are ordered by in[].
//   * Fenwick T: for every colour, +1 at each owner and -1 at the LCA of each
//     pair of dfs-consecutive owners of that colour. Summing T over
//     [in[x], ou[x]] therefore counts the distinct colours among the owners
//     inside x's subtree.
//   * setv (segment tree) records the colour of the latest paint covering
//     each dfs position; it is read only when x itself is not an owner.
// NOTE(review): only col[1..n] is cleared, so colours are presumably in
// [1, n]; confirm against the problem limits.
void solve () {
    int op , x , y ;
    scanf ( "%d" , &n ) ;
    cntE = dfs_idx = tree_idx = 0 ;
    s.clear () ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        H[i] = -1 ;
        col[i].clear () ;
        T[i] = 0 ;
    }
    // Read the n - 1 undirected edges, stored in both directions.
    for ( int i = 1 ; i < n ; ++ i ) {
        scanf ( "%d%d" , &x , &y ) ;
        addedge ( x , y ) ;
        addedge ( y , x ) ;
    }
    // Heavy-light decomposition rooted at vertex 1.
    dep[1] = 1 ;
    dfs ( 1 ) ;
    rebuild ( 1 , 1 ) ;
    // Every vertex starts as the owner of its own colour.
    for ( int i = 1 ; i <= n ; ++ i ) {
        scanf ( "%d" , &c[i] ) ;
        col[c[i]].insert ( make_pair ( in[i] , i ) ) ;
        s.insert ( make_pair ( in[i] , i ) ) ;
    }
    build ( 1 , 1 , n ) ;
    // Seed the Fenwick tree colour by colour. The local `pre` shadows the
    // global parent array here: it is the previous owner of colour i in dfs order.
    for ( int i = 1 ; i <= n ; ++ i ) {
        int pre = 0 ;
        for ( it = col[i].begin () ; it != col[i].end () ; ++ it ) {
            add ( it->first , 1 ) ;
            if ( pre ) add ( in[lca ( pre , it->second )] , -1 ) ;
            pre = it->second ;
        }
    }
    scanf ( "%d" , &q ) ;
    for ( int i = 1 ; i <= q ; ++ i ) {
        scanf ( "%d%d" , &op , &x ) ;
        pii tmp ( in[x] , x ) ;
        if ( op == 0 ) {
            // Paint the whole subtree of x with colour y.
            scanf ( "%d" , &y ) ;
            // Remove every owner u inside [in[x], ou[x]]. Dropping u from
            // col[c[u]] replaces the -1s at lca(prev, u) and lca(u, next) by a
            // single -1 at lca(prev, next), which is the shallower of the two.
            // The net change is +1 at the deeper one (returned by get()) and
            // -1 at u itself.
            while ( 1 ) {
                it = s.lower_bound ( tmp ) ;
                if ( it == s.end () || it->first > ou[x] ) break ;
                int u = it->second ;
                it = col[c[u]].find ( *it ) ;
                int x1 = get ( u ) ;
                if ( x1 ) add ( in[x1] , 1 ) ;
                add ( in[u] , -1 ) ;
                pii tmp1 = *it ;
                s.erase ( tmp1 ) ;
                col[c[u]].erase ( tmp1 ) ;
                c[u] = y ;
            }
            // x becomes the only owner in its subtree. Insertion is the mirror
            // image of removal: -1 at the deeper neighbour LCA, +1 at x.
            c[x] = y ;
            s.insert ( tmp ) ;
            col[y].insert ( tmp ) ;
            it = col[y].find ( tmp ) ;
            int x1 = get ( x ) ;
            if ( x1 ) add ( in[x1] , -1 ) ;
            add ( in[x] , 1 ) ;
            update ( in[x] , ou[x] , y , 1 , 1 , n ) ;
        } else {
            // Distinct colours among the owners inside x's subtree.
            int ans = sum ( ou[x] ) - sum ( in[x] - 1 ) ;
            it = s.lower_bound ( tmp ) ;
            // No owner at or after in[x]: the subtree has no owner and x
            // inherits exactly one colour from an ancestor.
            if ( it == s.end () ) ++ ans ;
            // x is not an owner, so its colour comes from the latest paint
            // covering it. Count that colour unless an owner inside the
            // subtree already has it.
            else if ( it->second != x ) {
                int color = query ( in[x] , 1 , 1 , n ) ;
                it = col[color].lower_bound ( tmp ) ;
                if ( it == col[color].end () || it->first > ou[x] ) ++ ans ;
            }
            printf ( "%d\n" , ans ) ;
        }
    }
}

// Reads the number of test cases and answers each one, prefixing its output
// with "Case #k:". The counter is named `cases` so that it does not shadow
// the global Fenwick array T.
int main () {
    int cases ;
    scanf ( "%d" , &cases ) ;
    for ( int cs = 1 ; cs <= cases ; ++ cs ) {
        printf ( "Case #%d:\n" , cs ) ;
        solve () ;
    }
    return 0 ;
}

压缩后代码:

#include <cstdio>
#include <cstring>
#include <set>
#include <utility>
using namespace std;
typedef long long LL;
typedef pair<int,int> pii;   // (dfs entry index, vertex id)
#define clr(a,x) memset(a,x,sizeof a)
const int MAXN=100005;       // per-vertex array bound
const int MAXE=200005;       // edge pool: each tree edge is stored in both directions
// Adjacency-list edge: target v, index n of the next edge (-1 ends the list).
struct Edge{
    int v,n;
    Edge(){}
    Edge(int v,int n):v(v),n(n){}
}E[MAXE];
int H[MAXN],cntE;
// HLD arrays; the subtree of u is the dfs range [in[u],ou[u]].
int dep[MAXN],pre[MAXN],top[MAXN],son[MAXN],siz[MAXN],tree_idx;
int in[MAXN],ou[MAXN],dfs_idx;
// s: vertices owning a paint; col[k]: those of colour k. Both keyed by (in[u],u).
set<pii>s,col[MAXN];
set<pii>::iterator it,it1;
int c[MAXN],T[MAXN],setv[MAXN<<2],n,q;
// Prepend the directed edge u->v to u's adjacency list.
void addedge(int u,int v){
    E[cntE]=Edge(v,H[u]);
    H[u]=cntE;
    ++cntE;
}
void dfs(int u){
    in[u]=++dfs_idx;
    siz[u]=1;
    son[u]=0;
    for(int i=H[u];~i;i=E[i].n){
        int v=E[i].v;
        if(v==pre[u])continue;
        pre[v]=u;
        dep[v]=dep[u]+1;
        dfs(v);
        siz[u]+=siz[v];
        if(siz[son[u]]void rebuild(int u,int top_element){
    top[u]=top_element;
    if(son[u])rebuild(son[u],top_element);
    for(int i=H[u];~i;i=E[i].n)if(E[i].v!=pre[u]&&E[i].v!=son[u])rebuild(E[i].v,E[i].v);
}
int lca(int x,int y){
    while(top[x]!=top[y])dep[top[x]]>dep[top[y]]?x=pre[top[x]]:y=pre[top[y]];
    return dep[x]void build(int o,int l,int r){
    setv[o]=0;
    if(l==r)setv[o]=c[in[l]];
    else {
        int m=l+r>>1;
        build(o<<1,l,m);
        build(o<<1|1,m+1,r);
    }
}
// Push a pending (non-zero) colour from node o to both children.
void down(int o){
    if(!setv[o])return;
    setv[o<<1]=setv[o];
    setv[o<<1|1]=setv[o];
    setv[o]=0;
}
// Assign colour v to dfs positions [L,R]; (o,l,r) is the current node and its range.
void update(int L,int R,int v,int o,int l,int r){
    if(L<=l&&r<=R)setv[o]=v;
    else{
        down(o);
        int m=l+r>>1;
        if(L<=m)update(L,R,v,o<<1,l,m);
        if(m<R)update(L,R,v,o<<1|1,m+1,r);
    }
}
int query(int x,int o,int l,int r){
    for(int m=l+r>>1;l1,r=m):(o=o<<1|1,l=m+1),m=l+r>>1)down(o);
    return setv[o];
}
// Fenwick point update: add v at position x (positions 1..n).
void add(int x,int v){
    for(;x<=n;x+=x&-x)T[x]+=v;
}
// Fenwick prefix sum over positions 1..x, added to the starting value ans.
int sum(int x,int ans=0){
    while(x>0)ans+=T[x],x&=x-1;
    return ans;
}
int get(int x,int x1=0,int x2=0){
    if(it!=col[c[x]].begin())x1=lca((--(it1=it))->second,x);
    if((++(it1=it))!=col[c[x]].end())x2=lca(it1->second,x);
    if(dep[x1]>dep[x2])return x1;
    return x2;
}
void solve(){
    int op,x,y;
    scanf("%d",&n);
    cntE=dfs_idx=tree_idx=0;
    s.clear();
    for(int i=1;i<=n;++i)H[i]=-1,col[i].clear(),T[i]=0;
    for(int i=1;iscanf("%d%d",&x,&y);
    dfs(dep[1]=1),rebuild(1,1);
    for(int i=1;i<=n;++i){
        scanf("%d",&c[i]);
        col[c[i]].insert(make_pair(in[i],i));
        s.insert(make_pair(in[i],i));
    }
    build(1,1,n);
    for(int i=1,pre=0;i<=n;++i,pre=0){
        for(it=col[i].begin();it!=col[i].end();++it){
            add(it->first,1);
            if(pre)add(in[lca(pre,it->second)],-1);
            pre=it->second;
        }
    }
    scanf("%d",&q);
    for(int i=1;i<=q;++i){
        scanf("%d%d",&op,&x);
        pii tmp(in[x],x);
        if(op==0){
            scanf("%d",&y);
            while(1){
                it=s.lower_bound(tmp);
                if(it==s.end()||it->first>ou[x])break;
                int u=it->second;
                it=col[c[u]].find(*it);
                int x1=get(u);
                if(x1)add(in[x1],1);
                add(in[u],-1);
                pii tmp1=*it;
                s.erase(tmp1);
                col[c[u]].erase(tmp1);
                c[u]=y;
            }
            c[x]=y;
            s.insert(tmp);
            col[y].insert(tmp);
            it=col[y].find(tmp);
            int x1=get(x);
            if(x1)add(in[x1],-1);
            add(in[x],1);
            update(in[x],ou[x],y,1,1,n);
        }else{
            int ans=sum(ou[x])-sum(in[x]-1);
            it=s.lower_bound(tmp);
            if(it==s.end())++ans;
            else if(it->second!=x){
                int color=query(in[x],1,1,n);
                it=col[color].lower_bound(tmp);
                if(it==col[color].end()||it->first>ou[x])++ans;
            }
            printf("%d\n",ans);
        }
    }
}
// Answer every test case, prefixing each with "Case #k:". The counter is named
// `cases` so it does not shadow the global Fenwick array T.
int main(){
    int cases;
    scanf("%d",&cases);
    for(int cs=1;cs<=cases;++cs)printf("Case #%d:\n",cs),solve();
    return 0;
}

你可能感兴趣的:(线段树,树状数组,最近公共祖先【LCA】,set)