n n n个点的一颗树,合法路径定义为一条路径上每个点的编号相差大于 K K K。求合法路径数
首先我们可以求不合法的路径数,这样我们就有了 K ∗ n K*n K∗n个不合法(即不能在同一个路径上)的点对。
然后这题就和之前一题jzoj6276一样了
大概就是用矩形表示不合法的路径,之后用扫面线求矩形的面积并即可。
#pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
#include
#include
#include
#include
using namespace std;
const int N=3e5+10;
struct node{
int to,next;
}a[N*2];
struct line{
int x,l,r,w;
}l[N*40];
bool operator<(line x,line y)
{return x.x<y.x;}
int n,K,tot,cnt,num;
int rfn[N],ed[N],f[N][21],dep[N];
int w[N*4],mark[N*4],ls[N];
long long ans;
__attribute__((optimize("O3"))) inline int read() {
int x=0,f=1; char c=getchar();
while(!isdigit(c)) {if(c=='-')f=-f;c=getchar();}
while(isdigit(c)) x=(x<<1)+(x<<3)+c-48,c=getchar();
return x*f;
}
void addl(int x,int y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;
return;
}
void dfs(int x,int fa){
rfn[x]=++cnt;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
dep[y]=dep[x]+1;
f[y][0]=x;dfs(y,x);
}
ed[x]=cnt;
return;
}
int LCA(int x,int y){
for(int i=20;i>=0;i--)
if(dep[f[y][i]]>dep[x])
y=f[y][i];
return y;
}
void addc(int x1,int x2,int y1,int y2){
if(x1>x2)swap(x1,x2);
if(y1>y1)swap(y1,y2);
l[++num]=(line){x1,y1,y2,1};
l[++num]=(line){x2+1,y1,y2,-1};
}
void Ban(int x,int y){
if(rfn[x]>rfn[y])swap(x,y);
if(rfn[x]<=rfn[y]&&rfn[y]<=ed[x]){
int top=LCA(x,y);
if(rfn[top]!=1)addc(1,rfn[top]-1,rfn[y],ed[y]);
if(ed[top]!=n)addc(rfn[y],ed[y],ed[top]+1,n);
}
else addc(rfn[x],ed[x],rfn[y],ed[y]);
return;
}
void Change(int x,int L,int R,int l,int r,int val){
if(L==l&&R==r){
mark[x]+=val;
if(mark[x])w[x]=r-l+1;
else if(l==r)w[x]=0;
else w[x]=w[x*2]+w[x*2+1];
return;
}
int mid=(L+R)>>1;
if(r<=mid)Change(x*2,L,mid,l,r,val);
else if(l>mid)Change(x*2+1,mid+1,R,l,r,val);
else Change(x*2,L,mid,l,mid,val),Change(x*2+1,mid+1,R,mid+1,r,val);
if(mark[x])w[x]=R-L+1;
else w[x]=w[x*2]+w[x*2+1];
return;
}
int main()
{
freopen("data.in","r",stdin);
int size = 256 << 20; //250M
char*p=(char*)malloc(size) + size;
__asm__("movl %0, %%esp\n" :: "r"(p) );
n=read();K=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
addl(x,y);addl(y,x);
}
dfs(1,1);
for(int i=1;i<=20;i++)
for(int j=1;j<=n;j++)
f[j][i]=f[f[j][i-1]][i-1];
for(int i=1;i<=n;i++)
for(int j=i+1;j<=min(i+K,n);j++)
Ban(i,j);
sort(l+1,l+1+num);
int z=1;
for(int i=1;i<=n;i++){
while(z<=num&&l[z].x<=i)
Change(1,1,n,l[z].l,l[z].r,l[z].w),z++;
ans+=w[1];
}
printf("%lld",1ll*n*(n-1)/2-ans+n);
}