一颗n个节点的树。
找三个不同编号的节点,使它们两两间距离相同(一条边距离视作1),求方案数。
在3522的版本中,n<=5000
在4543的版本中,n<=100000
我们来考虑DP
用f[i,j]表示以i为根的子树里与i距离为j的点的个数。
g[i,j]表示在以i为根的子树里,有多少对(x,y)满足x与y到它们lca的距离均为d,且i到它们的lca距离为d-j(容易看出第三个不在i子树内的与i距离为j的点能与这些点匹配成合法解)
接下来用x表示当前节点,y表示一个子节点。f与g都是实时更新的,表示当前做掉的儿子的信息,然后加入一个新的儿子的信息。
一开始先考虑从某一个儿子转移过来,然后此时统计有至少一个点在该儿子子树内时的答案,那么一定是有两个点在该儿子子树内,第三个点就是x。
此时答案就是g[x][0]。
然后接下来枚举其他儿子。
转移式是:
1、f[x,i]+=f[y,i-1]
2、g[x,i-1]+=g[y,i]
3、g[x,i+1]+=f[x,i+1]*f[y,i]
都比较好理解
每做完一个儿子,还要统计答案,就是三个点至少有一个但不是全部在这个新儿子子树里的答案个数。
ans+=f[x,i-1]*g[y,i]+g[x,i+1]*f[y,i]
分别是一个在儿子子树内和两个在儿子子树内的答案。
这样我们是n^2的。
空间当然也是n^2的……
其实想做n^2也可以完全不用这么做,可以通过一些dp得到一个点子树里距离它为i的点对数,然后再dp出不在一个点子树里与它距离为i的点个数,于是就可以算了。
接下来4543的算法在3522的算法上进行改进。
我们考虑……那个叫长链剖分吗?
就是选择一个深度最大的儿子当重儿子,一个点与重儿子间连的边叫重边,非重儿子称为轻儿子,非重边称为轻边。
从3522的算法可以看出,如果我们用指针来实现,一开始对一个儿子的信息进行位移(f[x][i]=f[y][i-1]和g[x][i]=g[y][i+1]相当于位移一位)可以O(1)实现!
我们不妨把选择的这个儿子就钦点为重儿子,那么接下来只需要对轻边做转移。
设mx[i]表示i子树内深度最大点。
那么一条轻边x到y转移的复杂度是O(dep[mx[y]]-dep[x])
深度最大点肯定是个叶子,不难看出,dep[mx[y]]-dep[x]正好是mx[y]所在重链的长度!
因此转移的总复杂度就是重链长度和,为n。
那么这个算法是线性的!
至于空间分配,当然也只需要o(n)的空间。
给每条重链分配正比于重链长度的空间即可。
具体指针分配空间可看代码实现。
因为我之前也不会这玩意所以我代码基本就是抄的,感谢neither_nor大爷的代码
#include
#include
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const int maxn=100000+10;
int h[maxn],go[maxn*2],next[maxn*2];
ll xdl[maxn*5];
ll *f[maxn],*g[maxn];
int mx[maxn],dep[maxn];
ll *gjx=xdl+5;
int i,j,k,l,t,n,m,tot;
ll ans;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void add(int x,int y){
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void dfs(int x,int F){
int i,t,y;
mx[x]=x;
t=h[x];
while (t){
y=go[t];
if (y!=F){
dep[y]=dep[x]+1;
dfs(y,x);
if (dep[mx[y]]>dep[mx[x]]) mx[x]=mx[y];
}
t=next[t];
}
t=h[x];
while (t){
y=go[t];
if (y!=F&&(mx[x]!=mx[y]||x==1)){
gjx+=dep[mx[y]]-dep[x]+1;
f[mx[y]]=gjx;
g[mx[y]]=(gjx+=1);
gjx+=(dep[mx[y]]-dep[x])*2+1;
}
t=next[t];
}
}
void dp(int x,int F){
int i,j,t,y;
t=h[x];
while (t){
y=go[t];
if (y!=F){
dp(y,x);
if (mx[y]==mx[x]){
f[x]=f[y]-1;
g[x]=g[y]+1;
}
}
t=next[t];
}
f[x][0]=1;
ans+=g[x][0];
t=h[x];
while (t){
y=go[t];
if (y!=F&&mx[x]!=mx[y]){
fo(j,0,dep[mx[y]]-dep[x]) ans+=f[x][j-1]*g[y][j]+g[x][j+1]*f[y][j];
fo(j,0,dep[mx[y]]-dep[x]){
f[x][j]+=f[y][j-1];
g[x][j-1]+=g[y][j];
g[x][j+1]+=f[x][j+1]*f[y][j];
}
}
t=next[t];
}
}
int main(){
n=read();
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
dep[1]=1;
dfs(1,0);
dp(1,0);
printf("%lld\n",ans);
}