首先我们知道,这个题可以N^2的做,我们先确定一个根,然后讨论下情况,合法的三个点只可能有三种情况,第一种是三个点有相同的lca,这种情况我们可以用tree-dp来解决,用dis[i][j]表示i为根的子树中距离i为j的点的数量,然后我们枚举儿子,处理出dis_[i][j]表示i子树中选两个点距离i为j的点对儿数量,且在不同子树中(到i路径无重合)。那么枚举儿子的时候可以处理出这种情况的答案。
还有一种情况为在i的子树中选2个不同子树中的点,也是就dis_[i][j],然后第三个点不在i的子树中,那么我们枚举每一个点,bfs出距离这个点距离为j的且不在i的子树中的点的数量为fuck_dis[i][j],那么ans+=dis_[i][j]*fuck_dis[i][j]。因为每个节点只会有一个父亲,所以最多只能向上引一条不重叠的链,保证了算法的正确性。
反思:这道题不算特别难,但是写了挺长时间,开始的时候之想到了第一中情况,第二种情况只考虑了对于两个点距离i为j,i的一个祖先距离i为j,然后这个祖先可以成为答案,就用i的dep来判一下就好了。但是还有其他的可能,比如在某个地方拐一下,这个点到i的距离为j,所以我又写了每个点到除去子树中的点的距离最远的点,大于了j就可以更新答案,但是没有考虑到这个点有多种选法,所以最后写的距离i为j的点且不在i的子树中的数量,这个开始我想的是tree-dp,想了想觉得很不好写,总之tree-dp也是n^2,所以就写了每个点的bfs。
备注:还有另外一种方法,其实我们第一种方法的本质就是讨论,我们可以发现所有满足条件的点对儿都可以通过选取不同的根来看成三个点有一个lca,那么我们可以枚举所有的点,然后bfs,每做一层我们都统计一下答案,方法和上一种方法的第一种情况类似,这种方法可以节省空间,但是觉得不是特别好写。
这道题没有链的数据,我用dis[i][j]存的每个点的信息,但是n^2的内存会超,所以就估了一下不会超过2500左右,然后就把数组开小了些,其实应该用vector来存或者最开始选根的时候选重心来保证这个,但是懒得写了= =。
/************************************************************** Problem: 3522 User: BLADEVIL Language: C++ Result: Accepted Time:2620 ms Memory:118816 kb ****************************************************************/ //By BLADEVIL #include <cstdio> #include <cstring> #include <algorithm> #define maxn 5010 #define inf (~0U>>1) using namespace std; int n,l; long long ans; int pre[maxn<<1],other[maxn<<1],last[maxn]; int que[maxn],dep[maxn],dis[maxn][2000],dis_[maxn][2000],que_[maxn],dep_[maxn],father[maxn]; int jump[maxn][20],fuck_dis[maxn][2000]; void connect(int x,int y) { pre[++l]=last[x];; other[l]=y; last[x]=l; } int main() { //freopen("hot.in","r",stdin); freopen("hot.out","w",stdout); scanf("%d",&n); for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); connect(x,y); connect(y,x); } int h=0,t=1; que[1]=1; dep[1]=1; while (h<t) { int cur=que[++h]; for (int p=last[cur];p;p=pre[p]) { if (dep[other[p]]) continue; que[++t]=other[p]; dep[other[p]]=dep[cur]+1; father[other[p]]=cur; } } //for (int i=1;i<=n;i++) printf("%d ",standard[i]); //for (int i=1;i<=n;i++) printf("%d ",que[i]);printf("\n"); for (int i=n;i;i--) { int cur=que[i],sum=0; for (int p=last[cur];p;p=pre[p]) { if (dep[other[p]]<dep[cur]) continue; for (int j=1;j<=dis[other[p]][0];j++) ans+=dis_[cur][j+1]*dis[other[p]][j]; ans+=dis_[cur][1]; for (int j=1;j<=dis[other[p]][0];j++) dis_[cur][j+1]+=dis[cur][j+1]*dis[other[p]][j]; dis_[cur][1]+=sum; dis[cur][0]=max(dis[cur][0],dis[other[p]][0]+1); for (int j=1;j<=dis[other[p]][0];j++) dis[cur][j+1]+=dis[other[p]][j]; dis[cur][1]++; sum++; } //printf("|%d %d\n",cur,dis[cur][1]); //printf("||%d %d\n",cur,ans); } for (int i=2;i<=n;i++) { memset(dep_,0,sizeof dep_); memset(que_,0,sizeof que_); int h=0,t=1; que_[1]=father[i]; dep_[father[i]]=1; dep_[i]=-1; while (h<t) { int cur=que_[++h]; for (int p=last[cur];p;p=pre[p]) { if ((dep_[other[p]])||(dep_[other[p]]==-1)) continue; que_[++t]=other[p]; dep_[other[p]]=dep_[cur]+1; } } for (int j=1;j<=n;j++) fuck_dis[i][dep_[j]]++; } /* for (int i=1;i<=n;i++) { printf("%d ",i); for (int j=1;j<=n;j++) if (fuck_dis[i][j]) printf("%d ",fuck_dis[i][j]); printf("\n"); } */ for (int i=1;i<=n;i++) for (int j=1;j<=dis[i][0];j++) ans+=dis_[i][j]*fuck_dis[i][j]; printf("%lld\n",ans); fclose(stdin); fclose(stdout); return 0; }