Codeforces-375D Tree and Queries(树上dsu)

题意

给定一棵 n n 个节点的树,每个节点上有一个颜色。有 m m 个询问,每次询问 u u 对应的子树中,有多少种颜色至少出现 k k 次。
1n,m105 1 ≤ n , m ≤ 10 5

思路

树上 dsu d s u 的入门题(即树上启发式合并)。首先,暴力比较好想,对于每一个询问的 u u ,在 u u 的子树中统计每种颜色出现的次数,得到出现次数超过 k k 次的颜色个数。复杂度 O(mn) O ( m n )
考虑优化。不难发现有些节点的值会在多个祖先中发挥作用,对于每一个子树查询都重新 dfs d f s 一遍会重复操作。于是,对于每一个子树的查询,最好能继承更多儿子中的状态,于是有了如下一段代码:

void dsu(int u,int f)
{
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v!=f&&v!=son[u])dsu(v,u),update(L[v],R[v],-1);  //递归轻儿子并删除信息
    }
    if(son[u])dsu(son[u],u);  //递归重儿子并保留信息
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v!=f&&v!=son[u])update(L[v],R[v],1);  //加入轻儿子的信息
    }
    update(L[u],L[u],1);  //加入子树根节点信息
    FOR(i,0,(int)ask[u].size()-1)Out[ask[u][i].id]=...  //回答询问
}

上面就是树上启发式合并的精髓部分,算法的流程如下:
对于某一个节点:
1.递归轻儿子对应的子树,再将其信息删去;
2.递归重儿子对应的子树;
3.加上轻儿子的信息;
4.加入子树根节点的信息;
5.回答这个子树的询问;
6.返回。
先关注这个算法的复杂度主要堆积在哪里,不难发现, dsu d s u 这个递归函数只是把 n n 各节点的扫了一遍,复杂度无疑只是 O(n) O ( n ) 。而 update u p d a t e 函数,由于要对节点的 dfs d f s 序给值,成为了复杂度堆积的地方。 dsu d s u 回溯上来的时候,这个子树的信息已经被加入了,而我们最后算的是重儿子,所以信息保留,只用再加上其他儿子和根的信息。由此可以看出,对于一个节点 v v ,如果的信息需要被加入 u u 的子树,就说明它一定是 u u 的轻儿子,而这也说明一定存在大小不亚于它的重儿子,使得合并后的 u u 子树大小至少扩大为原子树 u u 的两倍。每个节点的合并次数不超过 logn log ⁡ n ,复杂度就保证在了 O(nlogn) O ( n log ⁡ n ) 以下。
这就是神奇的树上 dsu d s u ,事实上 dsu d s u 就是一种聪明的暴力,把小集合并到大集合,使得每个元素在合并后,所在集合的大小至少变为原来两倍,每个元素合并次数均不超过 logn log ⁡ n ,复杂度仍保持在 O(nlogn) O ( n log ⁡ n ) 以下,树上的也是同样的道理,不一样的是它规定了合并的方式,通过选出重儿子使复杂度变为 dsu d s u 的复杂度。
有了模板之后,只用考虑怎么更新了。比较好想的是开一个记录每种颜色出现次数的数组 cnt c n t ,再利用树状数组,维护出现次数为某个次数的颜色种数,每次查询 [k,n] [ k , n ] 这个区间。但事实上不需要,只需用一个普通的数组 num n u m ,表示至少出现某个次数的颜色种数。那么只用在 col c o l 颜色再次出现时,保留原来在 cnt c n t 里的贡献 cnt[col] c n t [ c o l ] ,在 num[cnt[col]+1] n u m [ c n t [ c o l ] + 1 ] 加一即可,删除同理。

代码

#include
#include
#include
#include
#include
#include
#include
#define FOR(i,x,y) for(int i=(x);i<=(y);i++)
#define DOR(i,x,y) for(int i=(x);i>=(y);i--)
#define lowbit(x) ((x)&-(x))
#define N 100003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
    int head[maxn],to[maxm],nxt[maxm],tot;
    void clear(){memset(head,-1,sizeof(head));tot=0;}
    void add(int u,int v){to[++tot]=v,nxt[tot]=head[u],head[u]=tot;}
    #define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list1>G;
struct Query{int id,k;};
vectorask[N];
int L[N],R[N],ori[N],sz[N],son[N],cnt[N],num[N],col[N],res[N];
int ord,n,m;

void dfs(int u,int f)
{
    L[u]=++ord,ori[ord]=u,sz[u]=1,son[u]=0;
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v==f)continue;
        dfs(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]])son[u]=v;
    }
    R[u]=ord;
}
void add(int L,int R)
{
    FOR(i,L,R)
        num[++cnt[col[ori[i]]]]++;
}
void del(int L,int R)
{
    FOR(i,L,R)
        --num[cnt[col[ori[i]]]--];
}
void dsu(int u,int f)
{
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v!=f&&v!=son[u])dsu(v,u),del(L[v],R[v]);
    }
    if(son[u])dsu(son[u],u);
    EOR(i,G,u)
    {
        int v=G.to[i];
        if(v!=f&&v!=son[u])add(L[v],R[v]);
    }
    add(L[u],L[u]);
    FOR(i,0,(int)ask[u].size()-1)
        res[ask[u][i].id]=num[ask[u][i].k];
}
void clear()
{
    memset(cnt,0,sizeof(cnt));
    memset(num,0,sizeof(num));
    G.clear();ord=0;
    FOR(i,1,n)ask[i].clear();
}

int main()
{
    while(~scanf("%d%d",&n,&m))
    {
        clear();
        FOR(i,1,n)scanf("%d",&col[i]);
        FOR(i,1,n-1)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            G.add(u,v);G.add(v,u);
        }
        dfs(1,0);
        FOR(i,1,m)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            ask[u].push_back((Query){i,v});
        }
        dsu(1,0);
        FOR(i,1,m)printf("%d\n",res[i]);
    }
    return 0;
}

你可能感兴趣的:(题目)