给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
M行,表示每个询问的答案。
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
2
8
9
105
7
HINT:
N,M<=100000
暴力自重。。。
鸣谢seter
写的好爽!
查询的时候不能减两遍LCA那个点,因为LCA就没被计算了…也不能减两遍LCA的父亲,那样LCA就被计算两次…正确做法是LCA减一次,它的父亲减一次…
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
typedef long long LL;
const int SZ1 = 200010;
const int SZ2 = 2000010;
const int INF = 1000000010;
int n,m,len;
int head[SZ1],nxt[SZ2],to[SZ2],tot = 1;
void build(int f,int t)
{
to[++ tot] = t;
nxt[tot] = head[f];
head[f] = tot;
}
int val[SZ1],anc[SZ1][20],deep[SZ1];
struct node{
int l,r;
int cnt;
}tree[SZ2];
int Tcnt = 0;
int rt[SZ2];
void insert(int l,int r,int last,int &now,int x)
{
now = ++ Tcnt;
tree[now] = tree[last];
tree[now].cnt ++;
if(l == r) return ;
int mid = l + r >> 1;
if(x <= mid) insert(l,mid,tree[last].l,tree[now].l,x);
else insert(mid + 1,r,tree[last].r,tree[now].r,x);
}
int ask(int u,int v,int lca,int k)
{
int r1 = rt[u],r2 = rt[v],l1 = rt[anc[lca][0]],l2 = rt[lca];
int L = 1,R = len;
while(L < R)
{
int mid = L + R >> 1;
int d = tree[tree[r1].l].cnt + tree[tree[r2].l].cnt - tree[tree[l1].l].cnt - tree[tree[l2].l].cnt;
if(d >= k)
R = mid,r1 = tree[r1].l,r2 = tree[r2].l,l1 = tree[l1].l,l2 = tree[l2].l;
else
k -= d,L = mid + 1,r1 = tree[r1].r,r2 = tree[r2].r,l1 = tree[l1].r,l2 = tree[l2].r;
}
return L;
}
void dfs(int u,int fa)
{
insert(1,len,rt[fa],rt[u],val[u]);
deep[u] = deep[fa] + 1;
anc[u][0] = fa;
for(int i = 1;anc[u][i - 1];i ++)
anc[u][i] = anc[anc[u][i - 1]][i - 1];
for(int i = head[u];i;i = nxt[i])
{
int v = to[i];
if(v == fa) continue;
dfs(v,u);
}
}
int ask_lca(int u,int v)
{
if(deep[u] < deep[v]) swap(u,v);
if(deep[u] > deep[v])
{
int dd = deep[u] - deep[v];
for(int i = 16;i >= 0;i --)
if(dd >> i & 1)
u = anc[u][i];
}
if(u != v)
for(int i = 16;i >= 0;i --)
if(anc[u][i] != anc[v][i])
u = anc[u][i],v = anc[v][i];
if(u == v) return u;
return anc[u][0];
}
int lsh[SZ2];
int main()
{
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;i ++)
scanf("%d",&val[i]),lsh[++ lsh[0]] = val[i];
sort(lsh + 1,lsh + 1 + lsh[0]);
len = unique(lsh + 1,lsh + 1 + lsh[0]) - lsh - 1;
for(int i = 1;i <= n;i ++)
val[i] = lower_bound(lsh + 1,lsh + 1 + len,val[i]) - lsh;
for(int i = 1;i <= n - 1;i ++)
{
int x,y;
scanf("%d%d",&x,&y);
build(x,y); build(y,x);
}
dfs(1,0);
int lastans = 0;
for(int i = 1;i <= m;i ++)
{
int u,v,k;
scanf("%d%d%d",&u,&v,&k);
u ^= lastans;
lastans = lsh[ask(u,v,ask_lca(u,v),k)];
printf("%d",lastans);
if(i != m) puts("");
}
return 0;
}