CodeChef-PRIMEDST
采用点分治统计所有经过分治中心的路径。若朴素地两两合并各子树的距离计数，复杂度为 O(n^2)；因此用 FFT 加速这一多项式乘法（即距离计数数组的卷积），每层降为 O(n log n)，总体复杂度为 O(n log^2 n)。
做题步骤:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;
typedef long long ll;
const double pi = std::acos(-1.0);
const int N = 1e5+10;             // array capacity; vertices are indexed 1..n, so n must be < N
int head[N],tot;                  // adjacency lists: first edge index per vertex (-1 = none), edges used
// sum : size of the component searched by getrt()
// rt  : current centroid candidate chosen by getrt()
// dis : distance of a vertex from the current centroid (filled by getdis())
// tmp : distances collected by getdis()
// maxp: largest remaining piece if the vertex were removed; siz: subtree size
// up  : number of unordered pairs at prime distance; dn: total number of pairs
// len, d: scratch values (d is not referenced in the code shown — TODO confirm)
ll sum,rt,dis[N],tmp[N<<4],maxp[N],siz[N],len,d[N],up,dn;
int f[N<<2],g[N<<2];              // distance histograms for solve(); must cover the largest FFT length
int prime[N];                     // not referenced in the code shown
bool isprime[N],pcnt;             // NOTE: isprime[x] == 1 means x is NOT prime (0, 1, composites); see init()
ll ans[N];                        // not referenced in the code shown
int cnt,n;                        // cnt: number of entries written to tmp[]; n: vertex count
bool vis[N];                      // vis[v] == 1 once v has been used as a centroid
// Bit-reversal table for FFT. FFT lengths reach 2^18 for n close to N, so this
// must be sized like f/g/F/G/H, not N (the old size N overflowed).
int r[N<<2];
// Sieve of Eratosthenes over [0, N).
// Despite its name, isprime[x] ends up 1 when x is NOT prime (0, 1 and every
// composite) and stays 0 when x is prime.
void init(){
    isprime[0] = isprime[1] = 1;
    for (int p = 2; p < N; ++p) {
        if (isprime[p]) continue;  // already known composite
        for (int multiple = p + p; multiple < N; multiple += p)
            isprime[multiple] = 1;
    }
}
// Minimal complex number for the FFT: r = real part, i = imaginary part.
// Arithmetic operators are const so they also work on const operands.
struct cp
{
    double r, i;
    cp(double _r = 0, double _i = 0) : r(_r), i(_i) {}
    cp operator +(const cp &b) const { return cp(r + b.r, i + b.i); }
    cp operator -(const cp &b) const { return cp(r - b.r, i - b.i); }
    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    cp operator *(const cp &b) const { return cp(r * b.r - i * b.i, r * b.i + i * b.r); }
}F[N<<2],G[N<<2],H[N<<2];  // FFT work buffers; must be at least as long as the largest FFT length
// In-place iterative radix-2 FFT on s[0..n).
//   n   : transform length; must be a power of two.
//   inv : +1 for the forward transform, -1 for the inverse transform.
// The inverse transform is NOT scaled: the caller must divide each result by n.
void FFT(cp *s,int n,int inv){
    // A length-1 transform is the identity. Returning here also avoids the
    // negative shift count (bit - 1 == -1) the old table construction hit.
    if (n <= 1) return;
    // Bit-reversal permutation computed incrementally: j is the bit-reverse of i.
    // No global table is needed, so n is not limited by any table size.
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(s[i], s[j]);
    }
    for (int len = 2; len <= n; len *= 2) {
        cp wn = cp(cos(pi * 2.0 / len), inv * sin(pi * 2.0 / len));  // principal len-th root of unity
        for (int st = 0; st < n; st += len) {
            cp w = cp(1.0, 0.0);
            for (int i = st; i < st + len / 2; i++, w = w * wn) {
                cp x = s[i];                // even-index half
                cp y = w * s[i + len / 2];  // odd-index half, multiplied by the twiddle factor
                s[i] = x + y;
                s[i + len / 2] = x - y;
            }
        }
    }
}
// One directed adjacency-list entry: destination v, weight w, and nex = index
// of the next entry leaving the same source vertex (-1 ends the list).
struct edge{
    int v, w, nex;
    edge() {}
    edge(int to, int weight, int next) : v(to), w(weight), nex(next) {}
}edges[N<<1];  // two directed entries per undirected tree edge
// Push the directed edge u -> v (weight w) onto the front of u's adjacency list.
void add(int u,int v,int w){
    edges[tot] = edge(v, w, head[u]);
    head[u] = tot;
    ++tot;
}
// Centroid search over the component containing u (vertices with vis == 0),
// entered from parent fa. For each visited vertex it sets:
//   siz[]  : the subtree size;
//   maxp[] : the size of the largest piece left if that vertex were removed
//            (sum is the component size).
// It moves rt to the vertex with the smallest maxp seen so far.
void getrt(int u,int fa){
    siz[u] = 1;
    maxp[u] = 0;
    for (int e = head[u]; e != -1; e = edges[e].nex) {
        const int to = edges[e].v;
        if (to == fa || vis[to]) continue;
        getrt(to, u);
        siz[u] += siz[to];
        maxp[u] = max(maxp[u], siz[to]);
    }
    maxp[u] = max(maxp[u], sum - siz[u]);  // the piece "above" u
    if (maxp[u] < maxp[rt]) rt = u;
}
// Pre-order walk of the part of the component reachable from u without going
// back through fa or through removed (vis) vertices.
// Appends dis[] of every visited vertex to tmp[cnt..] (advancing cnt) and sets
// dis[child] = dis[parent] + edge weight. The caller must set dis[u] first.
void getdis(int u,int fa){
    tmp[cnt] = dis[u];
    ++cnt;
    for (int e = head[u]; e != -1; e = edges[e].nex) {
        const int to = edges[e].v;
        if (to == fa || vis[to]) continue;
        dis[to] = dis[u] + edges[e].w;
        getdis(to, u);
    }
}
// Raise mx to the largest edge-count depth reachable from u without going back
// through fa or through removed (vis) vertices; deep is u's own depth.
// mx is only ever increased, never lowered.
void getdeep(int u,int fa,int deep,int &mx){
    if (deep > mx) mx = deep;
    for (int e = head[u]; e != -1; e = edges[e].nex) {
        const int to = edges[e].v;
        if (!vis[to] && to != fa) getdeep(to, u, deep + 1, mx);
    }
}
void solve(int u){
int mx = -1;
getdeep(u,0,0,mx);
len = 1;
while(len<=2*mx)len<<=1;
for(int i=head[u];~i;i=edges[i].nex){
int v = edges[i].v;
if(vis[v])continue;
dis[v]=edges[i].w;
cnt = 0;
getdis(v,u);
for(int j=0;j<cnt;j++){
g[tmp[j]]++;
if(!isprime[tmp[j]])up++;
}
int bit=0;
while ((1<<bit)<len)bit++;
for(int j=0;j<len;j++){
G[j].r=g[j],G[j].i=0;
F[j].r=f[j],F[j].i=0;
}
FFT(F,len,1);
FFT(G,len,1);
for(int j=0;j<len;j++)H[j]=F[j]*G[j];
FFT(H,len,-1);
for(int j=0;j<len;j++){
if(!isprime[j])
up+=(ll)(H[j].r/len+0.5);
}
for(int j=0;j<cnt;j++)f[tmp[j]]++,g[tmp[j]]=0;
}
for(int i=0;i<=n;i++)f[i]=0;
}
// Centroid decomposition driver:
//   1. mark u as removed and count the prime-length paths through it (solve);
//   2. for each remaining neighbour, find the centroid of that neighbour's
//      component and recurse on it.
// Relies on siz[] having been filled by a getrt() call rooted at u beforehand.
void divide(int u){
    vis[u] = 1;
    solve(u);
    for (int e = head[u]; e != -1; e = edges[e].nex) {
        const int to = edges[e].v;
        if (vis[to]) continue;
        sum = siz[to];   // size of the neighbour's component
        rt = 0;          // vertex 0 is a sentinel whose maxp starts at sum
        maxp[0] = sum;
        getrt(to, 0);    // locate that component's centroid
        getrt(rt, 0);    // recompute siz[] rooted at the centroid for the recursion
        divide(rt);
    }
}
// Reads n, then n-1 undirected edges "u v" (vertex 0 is reserved as a sentinel
// parent, and the search starts from vertex 1). Every edge gets weight 1.
// Prints the fraction of unordered vertex pairs whose distance is prime.
int main(){
    init();
    memset(head,-1,sizeof(head));
    if (scanf("%d",&n) != 1) return 1;
    for(int i=1;i<n;i++){
        int u,v;
        if (scanf("%d %d",&u,&v) != 2) return 1;
        add(u,v,1);
        add(v,u,1);
    }
    // With fewer than two vertices there are no pairs; avoid printing 0/0 (NaN).
    if (n < 2) {
        printf("%.6lf\n", 0.0);
        return 0;
    }
    maxp[0]=sum=n;
    getrt(1,0);   // find the centroid of the whole tree
    getrt(rt,0);  // recompute siz[] rooted at that centroid
    divide(rt);
    dn = 1ll*n*(n-1)/2;
    printf("%.6lf\n",(double)up/(double)dn);
    return 0;
}