询问树上两点的路径上小于/等于/大于所给定值k得点的数量(每次询问k不一定相同)
强制在线
以计算小于k为例:
u到v的路径上小于k的点的数量=>u到根节点小于k的点的数量+v到根节点小于k的点的数量-2*lca(u,v)到根节点小于k的点的数量+lca(u,v)的权值是否小于k
然后裸的主席树
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int maxn=300000+10,maxm=200000+10;
int root[maxn],left[maxn*20],right[maxn*20],sum[maxn*20],a[maxn];
int h[maxn],go[maxn*2],next[maxn*2];
int d[maxn],f[maxn][20];
bool czy;
int i,j,k,l,t,n,m,q,tot,now,ans,u,v,w;
void add(int x,int y){
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void insert(int &x,int y,int l,int r,int a){
x=++tot;
left[x]=left[y];
right[x]=right[y];
sum[x]=sum[y]+1;
if (l==r) return;
int mid=(l+r)/2;
if (a<=mid) insert(left[x],left[y],l,mid,a);else insert(right[x],right[y],mid+1,r,a);
}
void dfs(int x,int y){
d[x]=d[y]+1;
f[x][0]=y;
insert(root[x],root[y],1,m,a[x]);
int t=h[x];
while (t){
if (go[t]!=y) dfs(go[t],x);
t=next[t];
}
}
int lca(int x,int y){
if (d[x]<d[y]) swap(x,y);
int j;
if (d[x]!=d[y]){
j=floor(log(d[x])/log(2));
while (j>=0){
if (d[f[x][j]]>d[y]) x=f[x][j];
j--;
}
x=f[x][0];
}
if (x==y) return x;
j=floor(log(d[x])/log(2));
while (j>=0){
if (f[x][j]!=f[y][j]){
x=f[x][j];
y=f[y][j];
}
j--;
}
return f[x][0];
}
int query(int x,int l,int r,int a,int b){
if (a>b||!x) return 0;
if (l==a&&r==b) return sum[x];
int mid=(l+r)/2;
if (b<=mid) return query(left[x],l,mid,a,b);
else if (a>mid) return query(right[x],mid+1,r,a,b);
else return query(left[x],l,mid,a,mid)+query(right[x],mid+1,r,mid+1,b);
}
int main(){
czy=1;
scanf("%d%d%d",&n,&m,&q);
fo(i,1,n) scanf("%d",&a[i]);
fo(i,1,n-1){
scanf("%d%d",&j,&k);
add(j,k);add(k,j);
}
tot=0;
dfs(1,0);
fo(j,1,floor(log(n)/log(2)))
fo(i,1,n)
f[i][j]=f[f[i][j-1]][j-1];
while (q--){
scanf("%d%d%d",&u,&v,&k);
if (czy) u^=now,v^=now,k^=now;
now=0;
w=lca(u,v);
ans=query(root[u],1,m,1,k-1)+query(root[v],1,m,1,k-1)-2*query(root[w],1,m,1,k-1);
if (a[w]<k) ans++;
printf("%d ",ans);
now^=ans;
ans=query(root[u],1,m,k,k)+query(root[v],1,m,k,k)-2*query(root[w],1,m,k,k);
if (a[w]==k) ans++;
printf("%d ",ans);
now^=ans;
ans=query(root[u],1,m,k+1,m)+query(root[v],1,m,k+1,m)-2*query(root[w],1,m,k+1,m);
if (a[w]>k) ans++;
printf("%d\n",ans);
now^=ans;
}
}