codechef Prime Distance On Tree(树分治+FFT)

题目链接:http://www.codechef.com/problems/PRIMEDST/

题意:给出一棵树,边长度都是1。每次任意取出两个点(u,v),他们之间的长度为素数的概率为多大?

思路:树分治。对于u的所有孩子两个两个处理。计算到u的距离,然后用FFT求卷积。枚举素数。

 

struct node

{

    double x,y;



    node(double _x=0.0,double _y=0.0)

    {

        x=_x;

        y=_y;

    }



    node operator+(node a)

    {

        return node(x+a.x,y+a.y);

    }



    node operator-(node a)

    {

        return node(x-a.x,y-a.y);

    }



    node operator*(node a)

    {

        return node(x*a.x-y*a.y,x*a.y+y*a.x);

    }

};



node A[N];

int L;





int reverse(int x)

{

    int ans=0,i;

    FOR0(i,L) if(x&(1<<i)) ans|=1<<(L-1-i);

    return ans;

}





void bitReverseCopy(node a[],int n)

{

    int i;

    FOR0(i,n) A[i]=a[i];

    FOR0(i,n)

    {

        a[reverse(i)]=A[i];

    }

}





void fft(node a[],int n,int on)

{

    bitReverseCopy(a,n);

    int len,i,j,k;

    node x,y,u,t;

    for(len=2;len<=n;len<<=1)

    {

        x=node(cos(-on*2*PI/len),sin(-on*2*PI/len));

        for(j=0;j<n;j+=len)

        {

            y=node(1,0);

            for(k=j;k<j+len/2;k++)

            {

                u=a[k];

                t=y*a[k+len/2];

                a[k]=u+t;

                a[k+len/2]=u-t;

                y=y*x;

            }

        }

    }

    if(on==-1)

    {

        FOR0(i,n) a[i].x/=n;

    }

}







int prime[N],tag[N],cnt;



void init()

{

    tag[0]=1;

    tag[1]=1;

    int i,j;

    for(i=2;i<N;i++) if(!tag[i])

    {

        prime[cnt++]=i;

        for(j=i+i;j<N;j+=i) tag[j]=1;

    }

}





vector<int> g[N];

int n,visit[N];

i64 ans;



int f[N],f1[N];

vector<int> V;



void dfs(int u,int pre)

{

    f[u]=0;

    f1[u]=1;

    int i,v;

    FOR0(i,SZ(g[u]))

    {

        v=g[u][i];

        if(v==pre||visit[v]) continue;

        dfs(v,u);

        f1[u]+=f1[v];

        upMax(f[u],f1[v]);

    }

    V.pb(u);

}



int getRoot(int u)

{

    V.clear(); dfs(u,0);

    int ans=u,p=SZ(V),i,temp,v;

    FOR0(i,SZ(V))

    {

        v=V[i];

        temp=max(SZ(V)-f[v],f[v]);

        if(temp<p) p=temp,ans=v;

    }

    return ans;

}



int d[N];

int MaxDis;

node a[N],b[N];



void DFS1(int u,int pre,vector<int> &V)

{

    d[u]=d[pre]+1; upMax(MaxDis,d[u]);

    V.pb(u);

    int i,v;

    FOR0(i,SZ(g[u]))

    {

        v=g[u][i];

        if(v==pre||visit[v]) continue;



        DFS1(v,u,V);

    }

}





vector<int> V1[N];



int P[N],Q[N];



int M;



void deal()

{

    int i;

    for(i=0;i<=MaxDis;i++) a[i]=node(P[i],0),b[i]=node(Q[i],0);

    while(i<M) a[i]=node(0,0),b[i]=node(0,0),i++;

    fft(a,M,1);

    fft(b,M,1);

    for(i=0;i<M;i++) a[i]=a[i]*b[i];

    fft(a,M,-1);

    i64 temp;

    for(i=0;i<cnt&&prime[i]<M;i++)

    {

        temp=(i64)(a[prime[i]].x+0.5);

        ans+=temp;

    }

}



void DFS(int u)

{

    u=getRoot(u);

    MaxDis=0; d[u]=0;

    int i,j,v,sonNum=0;

    FOR0(i,SZ(g[u]))

    {

        v=g[u][i];

        if(visit[v]) continue;

        V1[sonNum].clear();

        DFS1(v,u,V1[sonNum]);

        sonNum++;

    }

    M=1; L=0;

    while(M<=MaxDis+MaxDis) M<<=1,L++;

    for(i=0;i<=MaxDis;i++) P[i]=Q[i]=0;

    FOR0(i,sonNum)

    {

        FOR0(j,SZ(V1[i]))

        {

            v=V1[i][j];

            Q[d[v]]++;

            if(!tag[d[v]]) ans++;

        }

        deal();

        for(j=0;j<=MaxDis;j++) P[j]+=Q[j],Q[j]=0;

    }

    visit[u]=1;

    FOR0(i,SZ(g[u]))

    {

        v=g[u][i];

        if(!visit[v]) DFS(v);

    }

}



int main()

{

    init();

    Rush(n)

    {

        int i;

        FOR1(i,n) g[i].clear();

        FOR1(i,n-1)

        {

            int u,v;

            RD(u,v);

            g[u].pb(v); g[v].pb(u);

        }

        ans=0; DFS(1);

        printf("%.8lf\n",ans/(1.0*n*(n-1)/2));

    }

}

  

你可能感兴趣的:(code)