[JZOJ5898]【NOIP2018模拟10.6】距离统计

Description

给定一棵n个节点的带边权树,m组询问,每次询问两个数u,k,求出u本身外到u的第k小距离(相等距离会算多次)
n,m<=50000

Solution

这绝对假NOIP。。

首先肯定是二分答案,将问题转化为判定性问题,求有多少个距离小于mid的

把点分治树构出来,对于每个节点弄出以它为分治中心(点分树上以它为根的子树)的节点到它的距离,排好序。

查询某一个点的某一个距离,只需要从这个点开始向点分树上父亲跳,跳到一个父亲就二分查询个数。

这样有可能会算重,跳上来的那个子树应该被减掉。

只需要在点分树上记它的子树中所有节点到它父亲的距离,跳的时候在这个数组中一样二分,减掉即可。

复杂度 O ( n log ⁡ 3 n ) O(n\log^3 n ) O(nlog3n)(简直丧心病狂)

Code

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define N 50005
#define M 30*N
using namespace std;
int fs[N],nt[2*N],dt[2*N],n,m,q,pr[2*N],n1,sz[N],a1[N][2],d[M],dis[21][N],mi,mw,fa[N],dep[N],d2[M];
bool bz[N];
void link(int x,int y,int z)
{
	nt[++m]=fs[x];
	dt[fs[x]=m]=y;
	pr[m]=z;
}
void dfs(int k,int fa,int num)
{
	sz[k]=1;
	int mx=0;
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa&&!bz[p]) dfs(p,k,num),sz[k]+=sz[p],mx=max(mx,sz[p]);
	}
	if(max(mx,num-sz[k])<mi) mi=max(mx,num-sz[k]),mw=k;
}
void make(int k,int fa,int st)
{
	d[++d[0]]=dis[dep[st]][k];
	d2[d[0]]=dis[dep[st]-1][k];
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa&&!bz[p])
		{
			dis[dep[st]][p]=dis[dep[st]][k]+pr[i];
			make(p,k,st);
		}		
	}
}
void doit(int k,int num)
{
	dfs(k,0,num);
	dis[dep[k]][k]=0;
	a1[k][0]=d[0]+1;
	make(k,0,k);
	bz[k]=1;
	a1[k][1]=d[0];
	sort(d+a1[k][0],d+a1[k][1]+1);
	sort(d2+a1[k][0],d2+a1[k][1]+1);
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(!bz[p])
		{
			mi=sz[p]+1;
			dfs(p,k,sz[p]);
			fa[mw]=k;
			dep[mw]=dep[k]+1;
			doit(mw,sz[p]);
		}
	}
}
int ct(int x,int lim)
{
	int s=-1,st=x;
	int p=upper_bound(d+a1[x][0],d+a1[x][1]+1,lim)-d,q;
	s+=p-a1[x][0];
	int k=x;
	x=fa[x];
	while(x)
	{
		p=upper_bound(d+a1[x][0],d+a1[x][1]+1,lim-dis[dep[x]][st])-d;
		q=upper_bound(d2+a1[k][0],d2+a1[k][1]+1,lim-dis[dep[x]][st])-d2;
		s+=(p-a1[x][0])-(q-a1[k][0]);
		k=x;
		x=fa[x];
	}
	return s;
}
int main()
{
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	cin>>n>>q;
	int smx=0;
	fo(i,1,n-1)
	{
		int x,y,z;
		scanf("%d%d%d",&x,&y,&z);
		link(x,y,z),link(y,x,z);
		smx+=z;
	}
	mi=n+1;
	dfs(1,0,n);
	dep[mw]=1;
	doit(mw,n);
	fo(i,1,q)
	{
		int x,w;
		scanf("%d%d",&x,&w);
		int l=1,r=smx;
		while(l+1<r)
		{
			int mid=(l+r)>>1;
			if(ct(x,mid)>=w) r=mid;
			else l=mid;
		}
		if(ct(x,l)>=w) r=l;
		printf("%d\n",r);
	}
}

你可能感兴趣的:(题解,————点分治,————二分查找)