给定一个 n 个节点, n−1 条边的树。有 m 个玩家,第 i 个玩家从 xi 走树上最短路径到 yi 。玩家第 0 秒在自己的起点上,然后每秒移动一条边,移动到终点后结束移动。
每个节点上有一个观察员以及权值 wi 。如果有一个玩家在其移动的第 wi 秒恰好到达这个点,那么这个点上的观察员就会观察到他(如果这个点是终点,且玩家在 wi 秒之前到达不算)。
求每个点上的观察员分别观察到了多少个玩家。
1≤n≤299998,1≤m≤299998
我们将每条路径拆成从出发点到 lca ,和从 lca 的下一个点到结束点两段。
对于第一段,如果在点 i 时被观察到,我们可以写出 deep(x)−deep(i)=wi 。因此我们对所有 deep(x) 开权值线段树(下标为 DFN(x) ,动态开点),在每个玩家路径 lca 处挂一个在 deep(x) 的权值线段树中 DFN(x) 位置权值+1的标记。然后每访问一个点就要先处理该点的所有标记,在 deep(i)+wi 的权值线段树里面查询子树的权值和。但是因为在 lca 上面的点就不能观察到这个点,因此退出一个点时要撤销上面的标记。
对于第二段,如果点 i 时被观察到了,我们可以写出 deep(x)+deep(i)−2deep(lca(x,i))=wi ,和上面类似,那么我们对所有路径的 deep(x)−2deep(lca(x,y)) 开权值线段树(依然下标为 DFN(x) ,动态开点)。然后将标记挂在 lca 靠 y 的儿子那里(为了避免算重)。然后在 wi−deep(i) 的权值线段树中查询。
这样就要实现一个倍增和常数比较大的权值线段树, 时间复杂度 O(nlog2n) 。
听说CCF老爷机要卡log算法的常,那我也没办法了。
update:将我的权值线段树改成桶就可以 O(n) 了QwQ。
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cmath>
using namespace std;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}
int buf[30];
void write(int x)
{
if (x<0) putchar('-'),x=-x;
for (;x;x/=10) buf[++buf[0]]=x%10;
if (!buf[0]) buf[++buf[0]]=0;
for (;buf[0];buf[0]--) putchar('0'+buf[buf[0]]);
}
const int N=300050;
const int LGN=19;
const int M=300050;
const int E=N<<1;
const int Q=M<<1;
struct query
{
bool tp;
int key,x;
}qy[Q];
int qnxt[Q];
struct D
{
int key,x,y,lab;
}srt[M];
int d;
bool operator<(D a,D b){return a.key<b.key;}
struct segment_tree
{
int sum[N*20],son[N*20][2],root[M];
int tot;
int newnode()
{
sum[++tot]=0;
son[tot][0]=son[tot][1]=0;
return tot;
}
void update(int x){sum[x]=sum[son[x][0]]+sum[son[x][1]];}
void modify(int &x,int y,int l,int r,int delta)
{
if (!x) x=newnode();
if (l==r)
{
sum[x]+=delta;
return;
}
int mid=l+r>>1;
if (y<=mid) modify(son[x][0],y,l,mid,delta);
else modify(son[x][1],y,mid+1,r,delta);
update(x);
}
int query(int x,int st,int en,int l,int r)
{
if (!x) return 0;
if (st==l&&en==r) return sum[x];
int mid=l+r>>1;
if (en<=mid) return query(son[x][0],st,en,l,mid);
else if (mid+1<=st) return query(son[x][1],st,en,mid+1,r);
else return query(son[x][0],st,mid,l,mid)+query(son[x][1],mid+1,en,mid+1,r);
}
}t[2];
int deep[N],last[N],size[N],DFN[N],qlst[N],w[N],ans[N];
int tov[E],nxt[E];
int fa[N][LGN];
int n,m,tot,idx,lgn,cnt,dif;
void insert(int x,int y){tov[++tot]=y,nxt[tot]=last[x],last[x]=tot;}
void hang(int x,int y){qnxt[y]=qlst[x],qlst[x]=y;}
void dfs(int x)
{
DFN[x]=++idx,size[x]=1;
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x][0])
fa[y][0]=x,deep[y]=deep[x]+1,dfs(y),size[x]+=size[y];
}
void pre()
{
lgn=trunc(log(n)/log(2));
for (int j=1;j<=lgn;j++)
for (int i=1;i<=n;i++)
fa[i][j]=fa[fa[i][j-1]][j-1];
}
int adjust(int x,int d)
{
for (int i=lgn;i>=0;i--) if (deep[fa[x][i]]>=d) x=fa[x][i];
return x;
}
int lca(int x,int y)
{
if (deep[x]>deep[y]) swap(x,y);
y=adjust(y,deep[x]);
if (x==y) return x;
for (int i=lgn;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void proc()
{
sort(srt+1,srt+1+d);
dif=0,srt[0].key=-n*2;
for (int i=1;i<=d;i++)
{
dif+=srt[i].key!=srt[i-1].key;
srt[i].lab=dif;
qy[++cnt].tp=1,qy[cnt].key=dif,qy[cnt].x=srt[i].y;
hang(srt[i].x,cnt);
}
}
query tmp;
int search(int aim)
{
int ret=0;
int l=1,r=d,mid;
while (l<=r)
{
mid=l+r>>1;
if (srt[mid].key<=aim) ret=mid,l=mid+1;
else r=mid-1;
}
if (srt[ret].key==aim) return srt[ret].lab;
return 0;
}
//int total=0;
void calc(int x)
{
for (int i=qlst[x];i;i=qnxt[i])
{
tmp=qy[i];
//printf("0 %d %d\n",x,i);
t[tmp.tp].modify(t[tmp.tp].root[tmp.key],DFN[tmp.x],1,n,1);
}
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x][0]) calc(y);
ans[x]=0;
if (w[x]+deep[x]<=n) ans[x]+=t[0].query(t[0].root[w[x]+deep[x]],DFN[x],DFN[x]+size[x]-1,1,n);
int f=search(deep[x]-w[x]);
if (f) ans[x]+=t[1].query(t[1].root[f],DFN[x],DFN[x]+size[x]-1,1,n);
for (int i=qlst[x];i;i=qnxt[i])
{
tmp=qy[i];
/*if (x==3&&i==555634) { tmp.tp=qy[i].tp; }*/
//printf("1 %d %d\n",x,i);
t[tmp.tp].modify(t[tmp.tp].root[tmp.key],DFN[tmp.x],1,n,-1);
//printf("%d\n",++total);
}
}
int main()
{
freopen("running.in","r",stdin),freopen("running.out","w",stdout);
n=read(),m=read();
for (int i=1,x,y;i<n;i++)
{
x=read(),y=read();
insert(x,y),insert(y,x);
}
fa[1][0]=0,deep[1]=1,dfs(1),pre();
for (int i=1;i<=n;i++) w[i]=read();
for (int x,y,z;m--;)
{
x=read(),y=read(),z=lca(x,y);
qy[++cnt].tp=0,qy[cnt].key=deep[x],qy[cnt].x=x;
hang(z,cnt);
if (!(DFN[y]<=DFN[x]&&DFN[x]<=DFN[y]+size[y]-1)) srt[++d].key=deep[z]*2-deep[x],srt[d].x=adjust(y,deep[z]+1),srt[d].y=y;
}
proc(),calc(1);
for (int i=1;i<n;i++) write(ans[i]),putchar(' ');
write(ans[n]),putchar('\n');
fclose(stdin),fclose(stdout);
return 0;
}