题意:有N个节点,组成一棵树,1号节点是根节点。每个节点都有一个权值。现在有q个查询,每次查询根节点是U的子树中有多少个不同的权值恰好出现K次。
思路:第一眼让人感觉是图论,但仔细分析发现单纯的搜索肯定会超时。
这题分为另个部分:
1.以u为根节点的子树有多少种不同的权值。
2.权值为k的有多少。
对于线性列表,查询一个区间有多少个不同的子节点,我们用一个pre[v]记录权值v上次出现的位置。那么当第i个数的权值为v时,区间pre[v]+1-->i这个区间就会多一个不同的数,那么可以用update(pre[v]+1,1)来更新。但是i之后的状态不能改变,故再用一个update(i+1,-1)。这样就保证了只有pre[v]+1-->i的区间改变了。
所以,我们要对所有查询区间按右边界升序排序。当插入的值i为第j个查询的右边界时,其不同元素的个数就是对该查询的左节点求和,相当于该查询区间内有多少个不同的数v和pre[v]的连线与qt[i].r和qt[i].l的连线相交,若相交,必包含该查询的左节点qt[i].l,故只需对qt[i].l做求和操作。
完成上面的问题,那么我们就来解决现在问题。首先要将所有的权值离散化,这样才能使用pre[v]。离散化过后我们从1号节点开始对树遍历一遍:
void dfs(int u) { L[u]=++id;list[id]=art[u];//每个节点的左边界就是本身 int i; for(i=index[u];i!=-1;i=edge[i].next) { if(!L[edge[i].to]) dfs(edge[i].to); } R[u]=id;//右边界就是子树中id编号最大的 }
这样我们就将查找以u为根节点转化为第u个区间的问题。
然后然后,就是查找了。
这里的pre[v][k]表示权值为v的第k次出现的位置。
对于第i个元素来说,如果该节点的权值v出现的次数超过K,那么pre[v][i-k-1]+1,到pre[v-k]的区间的k出现次数就会超过k,那么用update(pre[v][i-k-1]+1,-1);
并且要用update(pre[v][i-k]+1,1);以确保之后的区间不受影响。
具体见代码:
#include<iostream> #include<cstdio> #include<algorithm> #include<map> #include<vector> #include<cstring> using namespace std; struct QT{ int l,r,id; bool operator <(const QT &temp) const { return r<temp.r; } }qt[100010]; vector<int> pre[100010]; struct Edge{ int to,next; }edge[300020]; int n,c[100010],e,index[100010],list[100010],art[100010],id=0; int L[100010],R[100010]; int ans[100010]; void addedge(int from,int to) { edge[e].to=to; edge[e].next=index[from]; index[from]=e++; edge[e].to=from; edge[e].next=index[to]; index[to]=e++; } int cmp(int x,int y) { return art[x]<art[y]; } void update(int pos,int val) { while(pos<=n) { c[pos]+=val; pos+=pos&(-pos); //cout<<"ok"<<endl; } } int Sum(int i) { int s=0; while(i) { s+=c[i]; i-=i&(-i); } return s; } void dis() { int r[100010],i; int temp=-1; for(i=1;i<=n;i++) r[i]=i; sort(r+1,r+1+n,cmp); int prev=art[r[1]]-1; for(i=1;i<=n;i++) { if(art[r[i]]!=prev) prev=art[r[i]],art[r[i]]=++temp; else art[r[i]]=temp; } for(i=0;i<=temp;i++) { pre[i].clear(); pre[i].push_back(0); } } void dfs(int u) { L[u]=++id;list[id]=art[u]; int i; for(i=index[u];i!=-1;i=edge[i].next) { if(!L[edge[i].to]) dfs(edge[i].to); } R[u]=id; } void init() { memset(c,0,sizeof(c)); memset(L,0,sizeof(L)); memset(R,0,sizeof(R)); memset(index,-1,sizeof(index)); id=0; e=0; } int main() { int i,j,q,cas,k,t,a,b; scanf("%d",&t); cas=0; while(t--) { init(); scanf("%d%d",&n,&k); for(i=1;i<=n;i++) scanf("%d",&art[i]); dis(); for(i=1;i<n;i++) { scanf("%d%d",&a,&b); addedge(a,b); } dfs(1); scanf("%d",&q); int x; for(i=1;i<=q;i++) { scanf("%d",&x); qt[i].l=L[x]; qt[i].r=R[x]; qt[i].id=i; } sort(qt+1,qt+q+1); j=1; // for(i=1;i<=n;i++) {//cout<<"ok"<<endl; int v=list[i];pre[v].push_back(i); int g=pre[v].size()-1; if(g>=k) { if(g>k) { update(pre[v][g-k-1]+1,-1); update(pre[v][g-k]+1,1); } update(pre[v][g-k]+1,1); update(pre[v][g-k+1]+1,-1); } while(qt[j].r==i) { ans[qt[j].id]=Sum(qt[j].l); j++; } } printf("Case #%d:\n",++cas); for(i=1;i<=q;i++) { printf("%d\n",ans[i]); } if(t) printf("\n"); } return 0; }