nssl1459-空间简单度【扫描线,线段树】

正题


题目大意

n n n个点的一颗树,合法路径定义为一条路径上每个点的编号相差大于 K K K。求合法路径数


解题思路

首先我们可以求不合法的路径数,这样我们就有了 K ∗ n K*n Kn个不合法(即不能在同一个路径上)的点对。

然后这题就和之前一题jzoj6276一样了

大概就是用矩形表示不合法的路径,之后用扫面线求矩形的面积并即可。


c o d e code code

#pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
#include
#include
#include
#include
using namespace std;
const int N=3e5+10;
struct node{
	int to,next;
}a[N*2];
struct line{
	int x,l,r,w;
}l[N*40];
bool operator<(line x,line y)
{return x.x<y.x;} 
int n,K,tot,cnt,num;
int rfn[N],ed[N],f[N][21],dep[N];
int w[N*4],mark[N*4],ls[N];
long long ans;
__attribute__((optimize("O3"))) inline int read() {
	int x=0,f=1; char c=getchar();
	while(!isdigit(c)) {if(c=='-')f=-f;c=getchar();}
	while(isdigit(c)) x=(x<<1)+(x<<3)+c-48,c=getchar();
	return x*f;
}
void addl(int x,int y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;
	return;
}
void dfs(int x,int fa){
	rfn[x]=++cnt;
	for(int i=ls[x];i;i=a[i].next){
		int y=a[i].to;
		if(y==fa)continue;
		dep[y]=dep[x]+1;
		f[y][0]=x;dfs(y,x);
	}
	ed[x]=cnt;
	return;
}
int LCA(int x,int y){
	for(int i=20;i>=0;i--)
		if(dep[f[y][i]]>dep[x])
			y=f[y][i];
	return y;
}
void addc(int x1,int x2,int y1,int y2){
	if(x1>x2)swap(x1,x2);
	if(y1>y1)swap(y1,y2);
	l[++num]=(line){x1,y1,y2,1};
	l[++num]=(line){x2+1,y1,y2,-1};
}
void Ban(int x,int y){
	if(rfn[x]>rfn[y])swap(x,y);
	if(rfn[x]<=rfn[y]&&rfn[y]<=ed[x]){
		int top=LCA(x,y);
		if(rfn[top]!=1)addc(1,rfn[top]-1,rfn[y],ed[y]);
		if(ed[top]!=n)addc(rfn[y],ed[y],ed[top]+1,n);
	}
	else addc(rfn[x],ed[x],rfn[y],ed[y]);
	return;
}
void Change(int x,int L,int R,int l,int r,int val){
	if(L==l&&R==r){
		mark[x]+=val;
		if(mark[x])w[x]=r-l+1;
		else if(l==r)w[x]=0;
		else w[x]=w[x*2]+w[x*2+1];
		return;
	}
	int mid=(L+R)>>1;
	if(r<=mid)Change(x*2,L,mid,l,r,val);
	else if(l>mid)Change(x*2+1,mid+1,R,l,r,val);
	else Change(x*2,L,mid,l,mid,val),Change(x*2+1,mid+1,R,mid+1,r,val);
	if(mark[x])w[x]=R-L+1;
	else w[x]=w[x*2]+w[x*2+1];
	return;
}
int main()
{
	freopen("data.in","r",stdin);
	int size = 256 << 20; //250M
	char*p=(char*)malloc(size) + size;
	__asm__("movl %0, %%esp\n" :: "r"(p) );
	n=read();K=read(); 
	for(int i=1;i<n;i++){
		int x=read(),y=read();
		addl(x,y);addl(y,x);
	}
	dfs(1,1);
	for(int i=1;i<=20;i++)
		for(int j=1;j<=n;j++)
			f[j][i]=f[f[j][i-1]][i-1];
	for(int i=1;i<=n;i++)
		for(int j=i+1;j<=min(i+K,n);j++)
			Ban(i,j);
	sort(l+1,l+1+num);
	int z=1;
	for(int i=1;i<=n;i++){
		while(z<=num&&l[z].x<=i)
			Change(1,1,n,l[z].l,l[z].r,l[z].w),z++;
		ans+=w[1];
	}
	printf("%lld",1ll*n*(n-1)/2-ans+n);
}

你可能感兴趣的:(数据结构,nssl,扫描线,线段树)