题意:给你一棵树,树上的每个节点都有权值,问以点u为根结点的子树里,权值刚好出现k次的有多少个。
首先把树型结构转化成线型结构(其实就是一次DFS)。线段树离线处理每个查询。
转化成线型结构后,对于每个查询,我们可以将以树型结构里u为根结点的子树转化成线型结构里的一段区间。问题就在,如果统计刚好出现k次。对于每个数,我们记录它在线型结构里出现的位置即p[](下标从0开始,|p|表示其元素个数),当|p|==k次的时间,我们将线段树里的结点p[|p|-k]加1,当|p|>k次的时候,我们将线段树里的结点p[|p|-t-1]减去2,并且|p|>k+1时,我们将线段树里的结点p[|p|-t-2]加1。
#include <iostream> #include <cstdio> #include <cstring> #include <vector> #include <algorithm> #include <map> using namespace std; #define LL(x) (x<<1) #define RR(x) (x<<1|1) const int N=100005; struct Query { int lft,rht,id; Query(){} Query(int a,int b,int c){lft=a;rht=b;id=c;} bool operator < (const Query &b) const { return rht<b.rht; } }; struct node { int lft,rht,sum; int mid(){return lft+(rht-lft)/2;} }; struct Segtree { node tree[N*4]; void build(int lft,int rht,int ind) { tree[ind].lft=lft; tree[ind].rht=rht; tree[ind].sum=0; if(lft!=rht) { int mid=tree[ind].mid(); build(lft,mid,LL(ind)); build(mid+1,rht,RR(ind)); } } void updata(int pos,int ind,int valu) { tree[ind].sum+=valu; if(tree[ind].lft==tree[ind].rht) return; else { int mid=tree[ind].mid(); if(pos<=mid) updata(pos,LL(ind),valu); else updata(pos,RR(ind),valu); } } int getsum(int be,int end,int ind) { int lft=tree[ind].lft,rht=tree[ind].rht; if(be<=lft&&rht<=end) return tree[ind].sum; else { int sum1=0,sum2=0; int mid=tree[ind].mid(); if(be<=mid) sum1=getsum(be,end,LL(ind)); if(end>mid) sum2=getsum(be,end,RR(ind)); return sum1+sum2; } } }seg; int data1[N],data2[N],dfn,low[N],high[N],ans[N]; vector<int> adj[N],pos[N]; vector<Query> query; map<int,int> imap; void dfs(int u,int fa) { low[u]=++dfn; data2[dfn]=data1[u]; for(int i=0;i<(int)adj[u].size();i++) { int v=adj[u][i]; if(v==fa) continue; dfs(v,u); } high[u]=dfn; } void init(int n) { dfn=0; imap.clear(); query.clear(); for(int i=1;i<=n;i++) adj[i].clear(); for(int i=1;i<=n;i++) pos[i].clear(); } int main() { int t,t_cnt=0; scanf("%d",&t); while(t--) { int n,m,k,a,b,sc=0; scanf("%d%d",&n,&k); init(n+5); for(int i=1;i<=n;i++) { scanf("%d",&data1[i]); if(imap.find(data1[i])==imap.end()) imap.insert(make_pair(data1[i],++sc)); } for(int i=0;i<n-1;i++) { scanf("%d%d",&a,&b); adj[a].push_back(b); adj[b].push_back(a); } dfs(1,0); scanf("%d",&m); for(int i=0;i<m;i++) { scanf("%d",&a); query.push_back(Query(low[a],high[a],i)); } sort(query.begin(),query.end()); int ind=0; seg.build(1,n,1); for(int i=1;i<=dfn;i++) { int id=imap[data2[i]]; pos[id].push_back(i); int cnt=(int)pos[id].size(); if(cnt==k) { seg.updata(pos[id][cnt-k],1,1); } else if(cnt>k) { seg.updata(pos[id][cnt-k],1,1); seg.updata(pos[id][cnt-k-1],1,-2); if(cnt>k+1) seg.updata(pos[id][cnt-k-2],1,1); } while(query[ind].rht==i&&ind<m) { ans[query[ind].id]=seg.getsum(query[ind].lft,query[ind].rht,1); ind++; } } if(t_cnt!=0) puts(""); printf("Case #%d:\n",++t_cnt); for(int i=0;i<m;i++) printf("%d\n",ans[i]); } return 0; }