N个节点的树,有R种属性,每个点属于一种属性。有Q次询问,每次询问r1,r2,回答有多少对(e1,e2)满足e1属性是r1,e2属性是r2,e1是e2的祖先。
数据规模
N≤200000,R≤25000,Q≤200000
30%数据R≤500
55%数据同种属性节点个数≤500
HOME Back
参考于《根号算法——不只是分块》 王悦同
设属性为 r1 的点有 A 个,属性为 r2 的点有 B 个。
若 A 很小(A ≤ √N),则设计一个 O(A log B) 的算法;若 B 很小(B ≤ √N),则设计一个 O(B log A) 的算法。
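若 A、B 都很大(均超过 √N),这样的属性至多只有 √N 种,去重后这类询问不会很多,直接做 O(A+B) 的线性归并即可:对固定的大属性 r1,它与所有大属性 r2 归并的总代价为 O(√N·A + N),再对所有大属性 r1 求和,总计 O(N√N)。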
针对不同类型询问设计不同的专杀算法,运用根号卡时间
蒟蒻编写能力还有待提高!注意数组下标……
// Counts ancestor/descendant pairs between two regions (attributes) of a
// rooted tree whose root is node 1, answering many (r1, r2) queries.
//
// Input (as read by main): N R Q, the region of node 1, then for i = 2..N the
// parent of node i followed by its region, then Q lines "r1 r2".
// Output: one line per query, in input order, with the number of pairs
// (e1, e2) such that region(e1) == r1, region(e2) == r2 and e1 is a proper
// ancestor of e2. When compiled with -DYZY, input is read from "yzy.txt".
//
// Nodes are numbered in DFS pre-order, so the proper descendants of the node
// at position p occupy exactly the positions p+1 .. tailAt[p]. With A = |r1|,
// B = |r2| and a threshold flag = floor(sqrt(N)), each distinct query uses:
//   solve1 (A small, B large): two binary searches per r1 node
//   solve2 (A large, B small): one binary search per r2 node
//   solve3 (otherwise)       : a linear merge of both sorted lists
// Identical queries are answered once.
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
using namespace std;

const int maxn = 2E5 + 20;

// One query: the region pair and its 1-based position in the input.
struct Q {
    int r1, r2, num;
    bool operator < (const Q &b) const {
        if (r1 != b.r1) return r1 < b.r1;
        return r2 < b.r2;
    }
    bool operator == (const Q &b) const {
        return r1 == b.r1 && r2 == b.r2;
    }
} q[maxn];

// Sweep event for the descendant ranges of one region: at DFS position `pos`
// the number of open ranges changes by `add`. After the events of a region
// are sorted, `sum` holds the prefix sum of `add`, so the last event with
// pos <= p tells how many ranges of that region cover position p.
struct P {
    int pos, add, sum;
    bool operator < (const P &b) const { return pos < b.pos; }
};

vector<int> v[maxn];     // v[k]: children of node k
vector<int> v2[maxn];    // v2[r]: DFS positions of the nodes in region r (ascending)
vector<P> v3[maxn];      // v3[r]: sweep events of region r (sorted by pos in main)
int mark[maxn];          // mark[k]: region of node k
int head[maxn];          // head[k]: DFS position of node k
int tailAt[maxn];        // tailAt[p]: last DFS position inside the subtree of the node at position p
int nextChild[maxn];     // iterative DFS: index of the next child of k to visit
int stk[maxn];           // iterative DFS: explicit stack of nodes
long long ans[maxn];     // ans[i]: answer of the i-th input query (may exceed 32 bits)
int n, m, dfs_clock = 0, flag;

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if EOF is reached first
// (the original looped forever on EOF because the char was never compared to EOF).
int getint() {
    int ret = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while ('0' <= ch && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
    return ret;
}

// Assigns the next DFS position to node k and records it under k's region.
// Positions are handed out in increasing order, so each v2[r] stays sorted.
static void enterNode(int k) {
    head[k] = ++dfs_clock;
    v2[mark[k]].push_back(head[k]);
}

// Called after k's whole subtree has been numbered. Records where the subtree
// ends and, if k has descendants, adds the open/close events of the range
// [head[k] + 1, dfs_clock], which holds exactly k's proper descendants
// (k itself is excluded, so a node is never counted as its own ancestor).
static void leaveNode(int k) {
    int h = head[k];
    tailAt[h] = dfs_clock;
    if (h < dfs_clock) {
        P opening = {h + 1, 1, 1};
        P closing = {dfs_clock + 1, -1, -1};
        v3[mark[k]].push_back(opening);
        v3[mark[k]].push_back(closing);
    }
}

// Pre-order DFS from `root`, using an explicit stack so that a path-shaped
// tree with N = 2e5 nodes cannot overflow the call stack. Children are
// visited in the order they were appended to v[k].
void DFS(int root) {
    int top = 0;
    enterNode(root);
    stk[top++] = root;
    while (top > 0) {
        int k = stk[top - 1];
        if (nextChild[k] < (int)v[k].size()) {
            int c = v[k][nextChild[k]++];
            enterNode(c);
            stk[top++] = c;
        } else {
            leaveNode(k);
            --top;
        }
    }
}

// |r1| small: for every r1 node with descendants, count the r2 positions
// inside its proper-descendant range with two binary searches.
void solve1(int x) {
    const vector<int> &A = v2[q[x].r1];
    const vector<int> &B = v2[q[x].r2];
    long long total = 0;
    for (size_t i = 0; i < A.size(); i++) {
        int lo = A[i] + 1;
        int hi = tailAt[A[i]];
        if (lo > hi) continue;  // leaf: no descendants
        total += upper_bound(B.begin(), B.end(), hi) - lower_bound(B.begin(), B.end(), lo);
    }
    ans[q[x].num] = total;
}

// |r2| small: for every r2 node, the number of r1 ranges covering its
// position is the prefix sum at the last r1 event with pos <= that position.
void solve2(int x) {
    const vector<P> &E = v3[q[x].r1];
    const vector<int> &B = v2[q[x].r2];
    long long total = 0;
    // r1 may consist of leaves only; then it has no events and no descendants.
    if (!E.empty()) {
        for (size_t i = 0; i < B.size(); i++) {
            P probe = {B[i], 0, 0};
            vector<P>::const_iterator it = upper_bound(E.begin(), E.end(), probe);
            if (it != E.begin()) total += (it - 1)->sum;
        }
    }
    ans[q[x].num] = total;
}

// Linear merge of r1's sorted events with r2's sorted positions: before an
// r2 node is counted, every event at or before its position is applied.
void solve3(int x) {
    const vector<P> &E = v3[q[x].r1];
    const vector<int> &B = v2[q[x].r2];
    long long total = 0;
    int covering = 0;  // number of r1 ranges open at the current position
    size_t e = 0;
    for (size_t j = 0; j < B.size(); j++) {
        while (e < E.size() && E[e].pos <= B[j]) covering += E[e++].add;
        total += covering;
    }
    ans[q[x].num] = total;
}

int main() {
#ifdef YZY
    freopen("yzy.txt","r",stdin);
#endif
    n = getint();
    int tt = getint();  // number of regions R
    m = getint();
    mark[1] = getint();
    flag = (int)sqrt((double)n);
    for (int i = 2; i <= n; i++) {
        int x = getint();  // parent of node i
        int y = getint();  // region of node i
        v[x].push_back(i);
        mark[i] = y;
    }
    DFS(1);
    // Sort each region's events by position and turn `sum` into prefix sums.
    for (int r = 1; r <= tt; r++) {
        sort(v3[r].begin(), v3[r].end());
        for (size_t j = 1; j < v3[r].size(); j++) v3[r][j].sum += v3[r][j - 1].sum;
    }
    for (int i = 1; i <= m; i++) {
        q[i].r1 = getint();
        q[i].r2 = getint();
        q[i].num = i;
    }
    // Sorting groups identical queries so each distinct pair is solved once.
    sort(q + 1, q + m + 1);
    for (int i = 1; i <= m; i++) {
        if (i > 1 && q[i] == q[i - 1]) {
            ans[q[i].num] = ans[q[i - 1].num];
            continue;
        }
        int S1 = (int)v2[q[i].r1].size();
        int S2 = (int)v2[q[i].r2].size();
        if (S1 <= flag && S2 <= flag) solve3(i);
        else if (S1 <= flag) solve1(i);
        else if (S2 <= flag) solve2(i);
        else solve3(i);
    }
    for (int i = 1; i <= m; i++) printf("%lld\n", ans[i]);
    return 0;
}