[dsu] CSU1811: Tree Intersection

1811: Tree Intersection

题意:

给一棵树,每个节点有种颜色,考虑第 i 条边 (ai,bi) ,把这条边去掉后,分成两颗树,问两颗树节点的颜色集合的交集。

题解:

如果对每棵子树暴力,显然是 O(n2) 的,但可以发现子树的信息有一部分父亲节点可以复用,保留这一部分信息就成了 O(nlogn) ,也就是保存重儿子的信息。

//
//  main.cpp
//  1811
//
//  Created by 翅膀 on 16/9/5.
//  Copyright © 2016年 kg20006. All rights reserved.
//

#include 
#include 
#include 
#include 
#include 
using namespace std;
const int N = 1e5+5;
vector<int>G[N];
int val[N], ful[N], col[N];
int n;
int sz[N], mark[N];
void predfs(int rt, int f){
    sz[rt] = 1;
    for(int i = 0; i < G[rt].size(); ++i){
        int &v = G[rt][i];
        if(v == f) continue;
        predfs(v, rt);
        sz[rt] += sz[v];
    }
}
int ans[N], tmpans;
void add(int rt, int f){
    if(col[val[rt]] == ful[val[rt]]) tmpans++;
    col[val[rt]]--;
    if(col[val[rt]] == 0) tmpans--;
    for(int i = 0; i < G[rt].size(); ++i){
        int &v = G[rt][i];
        if(!mark[v] && v != f) add(v, rt);
    }
}
void clr(int rt, int f){
    if(col[val[rt]] == 0) tmpans++;
    col[val[rt]]++;
    if(col[val[rt]] == ful[val[rt]]) tmpans--;
    for(int i = 0; i < G[rt].size(); ++i){
        int &v = G[rt][i];
        if(!mark[v] && v != f) clr(v, rt);
    }
}
typedef pair<int,int> pii;
mapint>id;
void dfs(int rt, int f, int kp){
    int mx = -1, son = -1;
    for(int i = 0; i < G[rt].size(); ++i){
        int &v = G[rt][i];
        if(v == f) continue;
        if(mx < sz[v]) mx = sz[v], son = v;
    }
    if(son != -1) mark[son] = 1;
    for(int i = 0; i < G[rt].size(); ++i){
        int &v = G[rt][i];
        if(v == f || v == son) continue;
        dfs(v, rt, 0);
    }
    if(son != -1) dfs(son, rt, 1);
    add(rt, f);
    if(rt != 1){
        int _id;
        if(rt > f) _id = id[pii(f, rt)];
        else _id = id[pii(rt, f)];
        ans[_id] = tmpans;
    }
    if(son != -1) mark[son] = 0;
    if(!kp) clr(rt, f);
}
int main(int argc, const char * argv[]) {
    while(scanf("%d", &n) != EOF) {
        id.clear();
        memset(col, 0, sizeof(col));
        tmpans = 0;
        for(int i = 1; i <= n; ++i) G[i].clear(), scanf("%d", val+i), col[val[i]]++;
        for(int i = 1; i <= n; ++i) ful[i] = col[i];
        for(int u, v, i = 1; i < n; ++i) {
            scanf("%d%d", &u, &v);
            if(u > v) swap(u, v);
            id[pii(u,v)] = i;
            G[u].push_back(v);
            G[v].push_back(u);
        }
        predfs(1, 0);
        dfs(1, 0, 0);
        for(int i = 1; i < n; ++i) printf("%d\n", ans[i]);
    }
    return 0;
}

你可能感兴趣的:(ACM,题解)