http://blog.csdn.net/howarli/article/details/51398321
抽象题意:求一个环套外向树中距离小于等于k的点有多少对
很显然,第一第二问是一样的,
在考场上想了半天先做子树再做环,想了半天想不出,却有了另一个想法:
我们先把环去掉一条边,这题就变成了一道经典的点分治,
统计完以后,我们再把边加上来,
现在问题变成了:求有多少对点过环边距离<=k,不过环边距离>k。
随便拉一个环上的点作为根,距离指于根节点的距离,我们现在统计的是经过根和环边的点对(也就是走到根节点又走环边下来)
设点x不经过环边与根的距离为 cx ,经过环边与根的距离为 Cx
首先, cx<=Cx 的点是不被考虑的,因为过了环距离也不能达到最优,肯定不走啦,而且这样还可以除掉那些沿着环绕了一圈又回来的点对,
我们先算环上的点:
设c=q的点有 sq 个,C=q的点有 Sq 个
当前点为x,答案就是:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fod(i,a,b) for(int i=a;i>=b;i--)
#define efo(i,q) for(int i=A[q];i;i=B[i][0])
#define NX(q) ((q)&(-(q)))
using namespace std;
typedef long long LL;
const int N=100500,maxlongint=2147483640;
int read(int &n)
{
char ch=getchar();
while((ch!='-')&&((ch<'0')||(ch>'9')))ch=getchar();
int q=0,w=1;if(ch=='-')w=-1,ch=getchar();
while(ch>='0' && ch<='9')q=q*10+ch-48,ch=getchar();n=q*w;return n;
}
int m,n,ce,ces,alln;
LL ans,ans1;
int b[N],d[N];
int B[N*2][3],A[N],B0=1;
struct qqww
{int c,n,d,g;}a[N];
int c[N],c1[N],C[N*2][2],CA[N];
LL cb[N],c1b[N];
bool z[N];
int HU[N],HB,HZ[N];
int D[N*2][2],D0,DA[N];
int min(int a,int b){return a>b?b:a;}
int max(int a,int b){return a>b?a:b;}
void join(int q,int w)
{
B[++B0][0]=A[q],A[q]=B0,B[B0][1]=w;
B[++B0][0]=A[w],A[w]=B0,B[B0][1]=q;
}
int tarjan(int q,int e,int fa)
{
if(c[q])return c[q];
c[q]=e;int w;
efo(i,q)if(B[i][1]!=fa)
{
w=tarjan(B[i][1],e+1,q);
if(w<=e)
{
if(!HU[0])B[HB=i][2]=B[HB^1][2]=1;
HZ[HU[++HU[0]]=q]=1;return w;
}
if(HU[0])return e;
}
return e;
}
int dfsf(int q,int fa)
{
a[q].n=1;
efo(i,q)if(!z[B[i][1]]&&B[i][1]!=fa&&(!B[i][2]))a[q].n+=dfsf(B[i][1],q);
return a[q].n;
}
void findc(int q,int fa)
{
int mx=alln-a[q].n;
efo(i,q)if(!z[B[i][1]]&&B[i][1]!=fa&&(!B[i][2]))findc(B[i][1],q),mx=max(mx,a[B[i][1]].n);
if(ces>mx)ce=q,ces=mx;
}
int dfsd(int q,int e,int fa)
{
int w=e;c1[e]++;c1b[e]+=(LL)b[q];
efo(i,q)if(!z[B[i][1]]&&B[i][1]!=fa&&(!B[i][2]))w=max(dfsd(B[i][1],e+1,q),w);
return w;
}
void divide(int q)
{
int I=1,J=1;
d[1]=q;
while(I<=J)
{
q=d[I];
alln=dfsf(q,q);
ces=maxlongint;
findc(q,q);
z[q=ce]=1;
fill(c,c+alln+2,0);fill(cb,cb+alln+2,0);
efo(i,q)if(!z[B[i][1]]&&(!B[i][2]))
{
int w=min(dfsd(B[i][1],1,q),min(m,alln));LL e=0,eb=0;
fo(j,1,min(m-w-1,alln))e+=c[j],eb+=cb[j];
fod(j,w,1)e+=(m-j<=alln?c[m-j]:0),eb+=(m-j<=alln?cb[m-j]:0),ans+=e*c1[j],ans1+=eb*c1b[j];
fo(j,1,w)c[j]+=c1[j],cb[j]+=c1b[j],c1[j]=c1b[j]=0;
}
fo(i,1,min(m,alln))ans+=(LL)c[i],ans1+=(LL)cb[i]*b[q];
efo(i,q)if(!z[B[i][1]]&&(!B[i][2]))d[++J]=B[i][1];
I++;
}
}
void joinc(int q,int w){C[++B0][0]=CA[q],CA[q]=B0,C[B0][1]=w;}
void joind(int q,int w){D[++D0][0]=DA[q],DA[q]=D0,D[D0][1]=w;}
void dfsyc(int q,int g,int e,int fa)
{
a[q].g=g;a[q].c=e,joind(e,q);
efo(i,q)if((!HZ[B[i][1]])&&(B[i][1]!=fa)&&(!B[i][2]))dfsyc(B[i][1],g,e+1,q);
}
void dfshy(int q,int e,int fa)
{
if(a[q].c>e)z[q]=1,joinc(e,q);a[q].d=e;
efo(i,q)if(B[i][1]!=fa&&(!B[i][2]))dfshy(B[i][1],e+1,q);
}
void addbst(int q,int q1,LL bs)
{
while(q<=n)c[q]++,cb[q]+=bs,q+=NX(q);
while(q1<=m)c1[q1]++,c1b[q1]+=bs,q1+=NX(q1);
}
void finda(int q,int hc,LL bs)
{
if(hc<0)return;hc=min(hc,n);
while(q)ans+=c1[q],ans1+=c1b[q]*bs,q-=NX(q);
while(hc)ans-=c[hc],ans1-=cb[hc]*bs,hc-=NX(hc);
}
void spjdo(int q)
{
memset(c,0,sizeof(c));memset(cb,0,sizeof(cb));
memset(c1,0,sizeof(c1));memset(c1b,0,sizeof(c1b));B0=1;
fod(i,HU[0],1) dfsyc(HU[i],HU[i],HU[0]-i+1,0);
dfshy(HU[1],1,q);
fo(i,1,m)
{
int w=m-i;
for(int j=CA[i];j;j=C[j][0])addbst(a[C[j][1]].c,a[C[j][1]].d,b[C[j][1]]);
for(int j=DA[w+1];j;j=D[j][0])
{
int q1=D[j][1];
finda(i,m-(a[q1].c-a[a[q1].g].c)+a[a[q1].g].c,b[q1]);
}
}
}
int main()
{
freopen("pronet.in","r",stdin);
freopen("pronet.out","w",stdout);
int q,w;
read(n);read(m);
fo(i,1,n) read(q),join(q,i);
fo(i,1,n)read(b[i]);
tarjan(1,1,1);
memset(c,0,sizeof(c));
divide(1);
if(HU[0])spjdo(B[HB][1]);
printf("%lld %lld\n",ans,ans1);
return 0;
}