传送门:bzoj4543
f [ i ] [ j ] f[i][j] f[i][j]表示 i i i子树中与 i i i距离为 j j j(相对深度为 j j j)的结点个数,
g [ i ] [ j ] g[i][j] g[i][j]表示 i i i子树中点对 ( x , y ) (x,y) (x,y),满足 d e p x = d e p y , d i s ( l c a ( x , y ) , i ) + j = d i s ( x , l c a ( x , y ) ) = d i s ( y , l c a ( x , y ) ) dep_x=dep_y,dis(lca(x,y),i)+j=dis(x,lca(x,y))=dis(y,lca(x,y)) depx=depy,dis(lca(x,y),i)+j=dis(x,lca(x,y))=dis(y,lca(x,y))的对数(如图中点对(1,2))。
对于点 x x x,统计其相当于图中B点(三元组 L C A LCA LCA)时的贡献:
首先 a n s + = g [ x ] [ 0 ] ans+=g[x][0] ans+=g[x][0],假设每次新加入的儿子结点为 y y y,则遍历所有合法深度 k k k, a n s + = g [ y ] [ k − 1 ] × f [ x ] [ k ] + g [ x ] [ k ] × f [ y ] [ k − 1 ] ans+=g[y][k-1]\times f[x][k]+g[x][k]\times f[y][k-1] ans+=g[y][k−1]×f[x][k]+g[x][k]×f[y][k−1]
转移:
g [ x ] [ k ] + = f [ x ] [ k ] ∗ f [ y ] [ k − 1 ] g[x][k]+=f[x][k]*f[y][k-1] g[x][k]+=f[x][k]∗f[y][k−1]
g [ x ] [ k ] + = g [ y ] [ k + 1 ] g[x][k]+=g[y][k+1] g[x][k]+=g[y][k+1]
f [ x ] [ k ] + = f [ y ] [ k − 1 ] f[x][k]+=f[y][k-1] f[x][k]+=f[y][k−1]
发现每次 k → k − 1 / k + 1 k\rightarrow k-1/k+1 k→k−1/k+1的变换可以用指针数组位移来 O ( 1 ) O(1) O(1)实现。
考虑长链剖分,每次继承重儿子的信息,暴力合并轻儿子。
时间复杂度 O ( n ) O(n) O(n)
#include
using namespace std;
const int N=1e5+20;
typedef long long ll;
int n,m,mxdp[N],d[N],son[N];
int head[N],to[N<<1],nxt[N<<1],tot;
ll ans,a[N*10],*nw=a;
ll *f[N],*g[N];
inline void lk(int u,int v)
{to[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void dfs(int x,int fr)
{
int i,j,k;
mxdp[x]=x;d[x]=d[fr]+1;
for(i=head[x];i;i=nxt[i]) if((j=to[i])!=fr){
dfs(j,x);
if(d[mxdp[j]]>d[mxdp[x]]) mxdp[x]=mxdp[j],son[x]=j;
}
for(i=head[x];i;i=nxt[i]){
j=to[i];if(j==fr || ((x!=1) && j==son[x])) continue;
j=mxdp[j];k=d[j]-d[x];
f[j]=(nw+=(k+1));g[j]=(++nw);
nw+=((k<<1)+5);
}
}
void dp(int x,int fr)
{
int i,j,k,t;
if(son[x]) {dp(son[x],x);f[x]=f[son[x]]-1;g[x]=g[son[x]]+1;}
ans+=g[x][0];f[x][0]=1;
for(i=head[x];i;i=nxt[i]){
j=to[i];if(j==fr || j==son[x]) continue;
dp(j,x);t=d[mxdp[j]]-d[x];
for(k=0;k<=t;++k) ans+=g[x][k+1]*f[j][k]+f[x][k-1]*g[j][k];
for(k=0;k<=t;++k){
g[x][k-1]+=g[j][k];
g[x][k+1]+=f[x][k+1]*f[j][k];
f[x][k+1]+=f[j][k];
}
}
}
int main(){
int i,x,y;
scanf("%d",&n);
for(i=1;i<n;++i){scanf("%d%d",&x,&y);lk(x,y);lk(y,x);}
dfs(1,0);dp(1,0);
printf("%lld",ans);
return 0;
}