给定一棵 n n 个节点的树,每个节点上有一个颜色。有 m m 个询问,每次询问 u u 对应的子树中,有多少种颜色至少出现 k k 次。
1≤n,m≤105 1 ≤ n , m ≤ 10 5
树上 dsu d s u 的入门题(即树上启发式合并)。首先,暴力比较好想,对于每一个询问的 u u ,在 u u 的子树中统计每种颜色出现的次数,得到出现次数超过 k k 次的颜色个数。复杂度 O(mn) O ( m n ) 。
考虑优化。不难发现有些节点的值会在多个祖先中发挥作用,对于每一个子树查询都重新 dfs d f s 一遍会重复操作。于是,对于每一个子树的查询,最好能继承更多儿子中的状态,于是有了如下一段代码:
void dsu(int u,int f)
{
EOR(i,G,u)
{
int v=G.to[i];
if(v!=f&&v!=son[u])dsu(v,u),update(L[v],R[v],-1); //递归轻儿子并删除信息
}
if(son[u])dsu(son[u],u); //递归重儿子并保留信息
EOR(i,G,u)
{
int v=G.to[i];
if(v!=f&&v!=son[u])update(L[v],R[v],1); //加入轻儿子的信息
}
update(L[u],L[u],1); //加入子树根节点信息
FOR(i,0,(int)ask[u].size()-1)Out[ask[u][i].id]=... //回答询问
}
上面就是树上启发式合并的精髓部分,算法的流程如下:
对于某一个节点:
1.递归轻儿子对应的子树,再将其信息删去;
2.递归重儿子对应的子树;
3.加上轻儿子的信息;
4.加入子树根节点的信息;
5.回答这个子树的询问;
6.返回。
先关注这个算法的复杂度主要堆积在哪里,不难发现, dsu d s u 这个递归函数只是把 n n 各节点的扫了一遍,复杂度无疑只是 O(n) O ( n ) 。而 update u p d a t e 函数,由于要对节点的 dfs d f s 序给值,成为了复杂度堆积的地方。 dsu d s u 回溯上来的时候,这个子树的信息已经被加入了,而我们最后算的是重儿子,所以信息保留,只用再加上其他儿子和根的信息。由此可以看出,对于一个节点 v v ,如果的信息需要被加入 u u 的子树,就说明它一定是 u u 的轻儿子,而这也说明一定存在大小不亚于它的重儿子,使得合并后的 u u 子树大小至少扩大为原子树 u u 的两倍。每个节点的合并次数不超过 logn log n ,复杂度就保证在了 O(nlogn) O ( n log n ) 以下。
这就是神奇的树上 dsu d s u ,事实上 dsu d s u 就是一种聪明的暴力,把小集合并到大集合,使得每个元素在合并后,所在集合的大小至少变为原来两倍,每个元素合并次数均不超过 logn log n ,复杂度仍保持在 O(nlogn) O ( n log n ) 以下,树上的也是同样的道理,不一样的是它规定了合并的方式,通过选出重儿子使复杂度变为 dsu d s u 的复杂度。
有了模板之后,只用考虑怎么更新了。比较好想的是开一个记录每种颜色出现次数的数组 cnt c n t ,再利用树状数组,维护出现次数为某个次数的颜色种数,每次查询 [k,n] [ k , n ] 这个区间。但事实上不需要,只需用一个普通的数组 num n u m ,表示至少出现某个次数的颜色种数。那么只用在 col c o l 颜色再次出现时,保留原来在 cnt c n t 里的贡献 cnt[col] c n t [ c o l ] ,在 num[cnt[col]+1] n u m [ c n t [ c o l ] + 1 ] 加一即可,删除同理。
#include
#include
#include
#include
#include
#include
#include
#define FOR(i,x,y) for(int i=(x);i<=(y);i++)
#define DOR(i,x,y) for(int i=(x);i>=(y);i--)
#define lowbit(x) ((x)&-(x))
#define N 100003
typedef long long LL;
using namespace std;
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
void clear(){memset(head,-1,sizeof(head));tot=0;}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u],head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list1>G;
struct Query{int id,k;};
vector ask[N];
int L[N],R[N],ori[N],sz[N],son[N],cnt[N],num[N],col[N],res[N];
int ord,n,m;
void dfs(int u,int f)
{
L[u]=++ord,ori[ord]=u,sz[u]=1,son[u]=0;
EOR(i,G,u)
{
int v=G.to[i];
if(v==f)continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
R[u]=ord;
}
void add(int L,int R)
{
FOR(i,L,R)
num[++cnt[col[ori[i]]]]++;
}
void del(int L,int R)
{
FOR(i,L,R)
--num[cnt[col[ori[i]]]--];
}
void dsu(int u,int f)
{
EOR(i,G,u)
{
int v=G.to[i];
if(v!=f&&v!=son[u])dsu(v,u),del(L[v],R[v]);
}
if(son[u])dsu(son[u],u);
EOR(i,G,u)
{
int v=G.to[i];
if(v!=f&&v!=son[u])add(L[v],R[v]);
}
add(L[u],L[u]);
FOR(i,0,(int)ask[u].size()-1)
res[ask[u][i].id]=num[ask[u][i].k];
}
void clear()
{
memset(cnt,0,sizeof(cnt));
memset(num,0,sizeof(num));
G.clear();ord=0;
FOR(i,1,n)ask[i].clear();
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
clear();
FOR(i,1,n)scanf("%d",&col[i]);
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v);G.add(v,u);
}
dfs(1,0);
FOR(i,1,m)
{
int u,v;
scanf("%d%d",&u,&v);
ask[u].push_back((Query){i,v});
}
dsu(1,0);
FOR(i,1,m)printf("%d\n",res[i]);
}
return 0;
}