线段树+位运算
首先对树进行DFS,写出DFS序列,记录下每一个节点控制的区间范围。然后就是区间更新和区间查询了。
某段区间的颜色种类可以用位运算来表示,方便计算。
如果仅有第i种颜色,那么就用十进制数(1<
如果A区间有的颜色是col1,B区间有的颜色是col2,合并之后有的就是(col1 | col2)
输出有几种,就是看得到的十进制数的二进制表示中有多少位是1.
#include #include #include #include #include #include #include<string> #include #include using namespace std; const int maxn=4*100000+10; int n,q; int col[maxn]; vector<int> Tree[maxn]; bool b[maxn]; int Left[maxn],Right[maxn]; int s[2*maxn]; int time; long long ans; struct SegTree { bool flag; long long ans; }segTree[2*maxn*4]; void dfs(int now) { b[now]=1; Left[now]=(++time); s[time]=now; for(int i=0;i) if(b[Tree[now][i]]==0) dfs(Tree[now][i]); Right[now]=(++time); s[time]=now; } void pushUp(int rt) { segTree[rt].ans=(segTree[2*rt].ans|segTree[2*rt+1].ans); } void pushDown(int rt) { if(segTree[rt].flag!=0) { segTree[2*rt].flag=segTree[2*rt+1].flag=segTree[rt].flag; segTree[2*rt].ans=segTree[2*rt+1].ans=segTree[rt].ans; segTree[rt].flag=0; } } void build(int l,int r,int rt) { if(l==r) { segTree[rt].flag=0; segTree[rt].ans=(long long)1<<((long long)col[s[l]]); return ; } int m=(l+r)/2; if(l<=m) build(l,m,2*rt); if(r>m) build(m+1,r,2*rt+1); pushUp(rt); return; } void quary(int L,int R,int l,int r,int rt) { if(L<=l&&r<=R) { ans=(ans|segTree[rt].ans); return; } pushDown(rt); int m=(l+r)/2; if(L<=m) quary(L,R,l,m,2*rt); if(R>m) quary(L,R,m+1,r,2*rt+1); pushUp(rt); return; } void update(int info,int L,int R,int l,int r,int rt) { if(L<=l&&r<=R) { segTree[rt].flag=info; segTree[rt].ans=(long long)1<<(long long) info; return; } pushDown(rt); int m=(l+r)/2; if(L<=m) update(info,L,R,l,m,2*rt); if(R>m) update(info,L,R,m+1,r,2*rt+1); pushUp(rt); } int main() { scanf("%d%d",&n,&q); for(int i=1;i<=n;i++) scanf("%d",&col[i]); for(int i=0;i<=n;i++) Tree[i].clear(); for(int i=1;i<=n-1;i++) { int x,y; scanf("%d%d",&x,&y); Tree[x].push_back(y); Tree[y].push_back(x); } memset(b,0,sizeof b); time=0,dfs(1); build(1,2*n,1); for(int i=1;i<=q;i++) { int tk; scanf("%d",&tk); if(tk==1) { int vk,ck; scanf("%d%d",&vk,&ck); update(ck,Left[vk],Right[vk],1,2*n,1); } else { int vk; scanf("%d",&vk); ans=0; quary(Left[vk],Right[vk],1,2*n,1); int num=0; while(ans) { if(ans%2==1) num++; ans=ans/2; } printf("%d\n",num); } } return 0; }
转载于:https://www.cnblogs.com/zufezzt/p/5151188.html