给定一棵 N 个节点的树，每个点有一个权值。对于 M 个询问 (u,v,k)，你需要回答 u xor lastans 和 v 这两个节点之间的路径上第 k 小的点权。其中 lastans 是上一个询问的答案，初始为 0，即第一个询问的 u 是明文。
鸣谢seter
/*
 * K-th smallest node weight on a tree path, answered online.
 *
 * Approach: every tree node x owns one version of a persistent segment tree
 * (root[x]) that counts the compressed weights of all nodes on the path from
 * node 1 down to x.  For a query (u, v, k) the multiset of weights on the
 * u..v path is
 *     root[u] + root[v] - root[lca] - root[parent(lca)],
 * and the k-th smallest value is found by descending the four versions in
 * lock-step.  The LCA is computed with binary lifting.
 */
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;

const int maxn = 200000 + 10;   // capacity for nodes / distinct values
const int maxlog = 20;          // number of binary-lifting levels stored per node

int n,m;                        // number of tree nodes, number of queries
int a[maxn];                    // a[i]: weight of node i, later replaced by its rank in num[]
int root[maxn];                 // root[i]: segment-tree version of node i; root[0] is the empty tree
int c[maxn*40];                 // c[o]: how many values are stored in the subtree of segment node o
int lc[maxn*40],rc[maxn*40];    // left / right child of each segment node
int cur = 1;                    // number of distinct weights kept in num[1..cur]
int ans;                        // answer of the previous query (0 before the first one)
int cnt;                        // number of segment-tree nodes allocated so far
int num[maxn];                  // sorted distinct weights, 1-based
int anc[maxn][maxlog];          // anc[x][j]: 2^j-th ancestor of x, 0 when it does not exist
int L[maxn];                    // L[x]: depth of x, node 1 has depth 0
vector <int> v[maxn];           // adjacency lists of the tree

/*
 * Creates a new version of the segment tree rooted at `o` (covering value
 * ranks [l, r]) with one extra occurrence of rank `pos`.  Only the nodes on
 * the root-to-leaf path are copied; the untouched child is shared with the
 * old version.  Returns the index of the new version's root.
 */
int Insert(int o,int l,int r,int pos)
{
    int ret = ++cnt;
    c[ret] = c[o] + 1;
    if (l == r) return ret;

    int mid = (l+r) >> 1;
    if (pos <= mid) {
        rc[ret] = rc[o];
        lc[ret] = Insert(lc[o],l,mid,pos);
    } else {
        lc[ret] = lc[o];
        rc[ret] = Insert(rc[o],mid+1,r,pos);
    }
    return ret;
}

/*
 * Reads the next run of decimal digits from stdin and returns it as an int.
 * Every non-digit character in front of the number is skipped (so a sign is
 * ignored).  Returns 0 when the input ends before any digit is found.
 */
int getint()
{
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters, so a truncated input can no longer loop forever.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();

    int ret = 0;
    while (ch >= '0' && ch <= '9') {
        ret = ret*10 + (ch - '0');
        ch = getchar();
    }
    return ret;
}

/*
 * Traverses the tree starting at `x`, whose parent is `fa` (0 for the tree
 * root).  For every node reached it
 *   - builds root[node] from the version of its parent (anc[node][0]),
 *   - records the parent in anc[child][0] and the depth in L[child]
 *     for each child.
 *
 * The traversal uses an explicit stack instead of recursion so that a
 * chain-shaped tree with ~1e5 nodes cannot overflow the call stack.  A node
 * is only pushed after its parent's version exists, so every version is
 * still built on top of the correct predecessor.
 */
void dfs(int x,int fa)
{
    vector <int> pending(1,x);
    while (!pending.empty()) {
        const int node = pending.back();
        pending.pop_back();

        // The start node's parent is given by the caller; every other node
        // had anc[node][0] filled in when it was pushed.
        const int parent = (node == x) ? fa : anc[node][0];
        root[node] = Insert(root[anc[node][0]],1,cur,a[node]);

        for (size_t i = 0; i < v[node].size(); i++) {
            const int to = v[node][i];
            if (to == parent) continue;
            anc[to][0] = node;
            L[to] = L[node] + 1;
            pending.push_back(to);
        }
    }
}

/*
 * Builds the empty base segment tree over ranks [l, r] rooted at node `o`:
 * all counts are zero and children are freshly allocated from cnt.
 */
void build(int o,int l,int r)
{
    c[o] = 0;
    if (l == r) return;

    int mid = (l+r) >> 1;
    lc[o] = ++cnt;
    rc[o] = ++cnt;
    build(lc[o],l,mid);
    build(rc[o],mid+1,r);
}

/*
 * Returns the lowest common ancestor of p and q using the binary-lifting
 * table anc[][] and the depths L[].
 */
int LCA(int p,int q)
{
    if (L[p] < L[q]) swap(p,q);

    // Highest level worth trying: the largest j >= 0 with 2^j <= L[p]
    // (0 when p is at depth 0 or 1).
    int top = 0;
    while (L[p] - (1 << (top + 1)) >= 0) ++top;

    // Lift the deeper node p to the depth of q.
    for (int j = top; j >= 0; j--)
        if (L[p] - (1<<j) >= L[q]) p = anc[p][j];
    if (p == q) return p;

    // Lift both nodes while their ancestors differ; they end up as two
    // distinct children of the LCA.
    for (int j = top; j >= 0; j--)
        if (anc[p][j] != anc[q][j]) {
            p = anc[p][j];
            q = anc[q][j];
        }
    return anc[p][0];
}

/*
 * Finds the pos-th smallest weight in the multiset
 *     version(o1) + version(o2) - version(f1) - version(f2)
 * restricted to ranks [l, r], and returns the original (uncompressed) weight.
 * All four arguments are segment nodes covering the same range.
 */
int query(int o1,int o2,int f1,int f2,int l,int r,int pos)
{
    if (l == r) return num[l];

    int mid = (l+r) >> 1;
    // Number of path values whose rank falls into the left half.
    int tot = c[lc[o1]] + c[lc[o2]] - c[lc[f1]] - c[lc[f2]];
    if (tot >= pos)
        return query(lc[o1],lc[o2],lc[f1],lc[f2],l,mid,pos);
    return query(rc[o1],rc[o2],rc[f1],rc[f2],mid+1,r,pos-tot);
}

/*
 * Input : n m, then n weights, then n-1 edges, then m queries "u v k"
 *         where the real u is (u xor previous answer).
 * Output: one answer per line, with no newline after the last answer.
 */
int main()
{
    #ifdef YZY
        freopen("data0.in","r",stdin);
        freopen("yzy.txt","w",stdout);
    #endif

    n = getint();
    m = getint();
    for (int i = 1; i <= n; i++) {
        a[i] = getint();
        num[i] = a[i];
    }

    // Coordinate compression: num[1..cur] becomes the sorted distinct weights.
    sort(num + 1,num + n + 1);
    for (int i = 2; i <= n; i++)
        if (num[i] != num[i-1]) num[++cur] = num[i];

    for (int i = 1; i < n; i++) {
        int x = getint();
        int y = getint();
        v[x].push_back(y);
        v[y].push_back(x);
    }

    // root[0] is the empty version used as the "parent" of node 1.
    root[0] = ++cnt;
    build(root[0],1,cur);

    for (int i = 1; i <= n; i++)
        a[i] = lower_bound(num + 1,num + cur + 1,a[i]) - num;

    L[1] = 0;
    dfs(1,0);

    // Fill the remaining binary-lifting levels; anc[0][*] stays 0.
    for (int i = 1; i < maxlog; i++)
        for (int j = 1; j <= n; j++)
            anc[j][i] = anc[anc[j][i-1]][i-1];

    for (int i = 1; i <= m; i++) {
        int x = getint();
        int y = getint();
        int k = getint();
        x ^= ans;   // decode u with the previous answer

        int lca = LCA(x,y);
        ans = query(root[x],root[y],root[lca],root[anc[lca][0]],1,cur,k);

        // The last answer is printed without a trailing newline.
        if (i < m) printf("%d\n",ans);
        else printf("%d",ans);
    }
    return 0;
}