题目链接:点击打开链接
题意:给你一棵树,编号1~n,告诉你根结点是1。 每次有两个操作:
1,将以v为根的子树的结点全部染成颜色c
2,问以v为根的紫书的结点的颜色种类。
思路:如果这是一条线段的话, 那么这就是线段树的区间更新问题,而现在是一棵树。
因为告诉了根结点是1, 那么这棵树的任意一个结点的子树就是确定的, 所以我们可以用DFS的先序遍历,将所有结点重新编号,因为先序遍历的话, 任意一个结点和其子树的编号就是一条连续的线段了,在这其中维护每个结点的新编号, 和这个结点的子树中的最大编号即可。
然后就是线段树区间更新了, 由于颜色数最大60, 用long long通过位运算的 | 操作就行了, 注意对1左移的时候应该先将1转成long long再进行操作。
细节参见代码:
#include<cstdio> #include<cstring> #include<algorithm> #include<iostream> #include<string> #include<vector> #include<stack> #include<bitset> #include<cstdlib> #include<cmath> #include<set> #include<list> #include<deque> #include<map> #include<queue> #define Max(a,b) ((a)>(b)?(a):(b)) #define Min(a,b) ((a)<(b)?(a):(b)) using namespace std; typedef long long ll; const double PI = acos(-1.0); const double eps = 1e-6; const int mod = 1000000000 + 7; const int INF = 1000000000; const int maxn = 400000 + 10; int T,n,m,u,v,id[maxn],a[maxn],cnt,last[maxn],b[maxn],setv[maxn<<2]; bool vis[maxn]; ll sum[maxn<<2]; vector<int> g[maxn]; void dfs(int root) { id[root] = ++cnt; vis[root] = true; int len = g[root].size(); for(int i=0;i<len;i++) { int v = g[root][i]; if(!vis[v]) { dfs(v); } } last[root] = cnt; } void PushUp(int o) { sum[o] = sum[o<<1] | sum[o<<1|1]; } void pushdown(int l, int r, int o) { if(setv[o]) { setv[o<<1] = setv[o<<1|1] = setv[o]; sum[o<<1] = sum[o<<1|1] = (1LL<<setv[o]); setv[o] = 0; } } void build(int l, int r, int o) { int m = (l + r) >> 1; setv[o] = 0; if(l == r) { sum[o] = 1LL<<b[++cnt]; return ; } build(l, m, o<<1); build(m+1, r, o<<1|1); PushUp(o); } void update(int L, int R, int v, int l, int r, int o) { int m = (l + r) >> 1; if(L <= l && r <= R) { setv[o] = v; sum[o] = (1LL << v); return ; } pushdown(l, r, o); if(L <= m) update(L, R, v, l, m, o<<1); if(m < R) update(L, R, v, m+1, r, o<<1|1); PushUp(o); } ll query(int L, int R, int l, int r, int o) { int m = (l + r) >> 1; if(L <= l && r <= R) { return sum[o]; } pushdown(l, r, o); ll ans = 0; if(L <= m) ans |= query(L, R, l, m, o<<1); if(m < R) ans |= query(L, R, m+1, r, o<<1|1); PushUp(o); return ans; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&a[i]); g[u].clear(); } for(int i=1;i<n;i++) { scanf("%d%d",&u,&v); g[u].push_back(v); g[v].push_back(u); } memset(vis, 0, sizeof(vis)); cnt = 0; dfs(1); for(int i=1;i<=n;i++) { b[id[i]] = a[i]; } cnt = 0; build(1, n, 1); int res, v, c; while(m--) { scanf("%d",&res); if(res == 1) { scanf("%d%d",&v,&c); update(id[v], last[v], c, 1, n, 1); } else { scanf("%d",&v); ll ans = query(id[v], last[v], 1, n, 1); int cc = 0; for(int i=1;i<=61;i++) { if(ans & (1LL<<i)) cc++; } printf("%d\n",cc); } } return 0; }