【点分治+拆系数FFT】CodeChef - CUTTREE (Chef Cuts Tree )

【题目】
原题地址
题目大意:定义森林的强度为连通块大小的平方和。
第0天有一棵 n n 个节点的树,每一天随机删掉森林中的一条边,共进行 n1 n − 1 天,对于每一天求出森林强度的期望值。答案模 109+7 10 9 + 7

【题目分析】
根据套路,我们一般是讨论树上每个点对对答案的贡献,这样一般可以用点分治来做。

【解题思路】
对于一个有序点对 (x,y) ( x , y ) ,若在第 i i 天结束时这两点连通,则会对森林的强度作出1的贡献。那么我们要求的实际上就是每天期望联通的点对数量。
dis(x,y) d i s ( x , y ) 表示 x x y y 的距离,则 (x,y) ( x , y ) 对第 i i 天的贡献为:

(n1dis(x,y))i(n1)i ( n − 1 − d i s ( x , y ) ) i ( n − 1 ) i

即一个联通块的价值可以理解为任意两点连通所以经过的边数。

如果我们预处理出 cnti c n t i 表示距离为 i i 的点对有多少对,那么

ansi=d=0n1(nd1)!(ni1)!cntd(ndi1)!(n1)! a n s i = ∑ d = 0 n − 1 ( n − d − 1 ) ! ∗ ( n − i − 1 ) ! ∗ c n t d ( n − d − i − 1 ) ! ( n − 1 ) !

可以发现处理出来后就是一个卷积的形式,用FFT优化即可。
同理求 cnt c n t 也是可以用FFT的。

由于模数不是NTT模数,所以要上任意拆系数FFT。(三模数NTT应该会TLE,因为我拆系数FFT跑了1.1s)
然后要用long double,不然只能过25%的数据。

【参考代码】

#include
#define mkp(x,y) make_pair(x,y)
using namespace std;

typedef long long LL;
typedef long double ldb;
typedef pair<int,int> pii;
const int N=4e5+10;
const int mod=1e9+7;
const int M=32768;
const ldb pi=acos(-1);
int n,m,L,tot,root,mx,sum;
int son[N],siz[N],rev[N],vis[N];
int head[N],cnt[N],dis[N];
LL inv[N],fac[N],ans[N],dep[N];

LL read()
{
    LL ret=0,f=1;char c=getchar();
    while(!isdigit(c)){if(c=='-')f=0;c=getchar();}
    while(isdigit(c)){ret=(ret<<1ll)+(ret<<3ll)+(c^48);c=getchar();}
    return f?ret:-ret;
}

struct Tway
{
    int v,nex;
};
Tway es[N<<1];

void add(int u,int v)
{
    es[++tot]=(Tway){v,head[u]};head[u]=tot;
    es[++tot]=(Tway){u,head[v]};head[v]=tot;
}

struct E
{
    ldb r,i;
    E(){};
    E(ldb rx,ldb ix){r=rx;i=ix;}

    E operator + (const E&A)const{
        return E(r+A.r,i+A.i);
    }

    E operator - (const E&A)const{
        return E(r-A.r,i-A.i);
    }

    E operator * (const E&A)const{
        return E(r*A.r-i*A.i,r*A.i+i*A.r);
    }
};
E da,db,dc,dd;
E a[N],b[N],e[N],f[N],g[N],h[N];

E conj(E a) {return E(a.r,-a.i);}

void fft(E *a,int n,int f)
{
    for(int i=0;iif(ifor(int i=1;i1)
    {
        E wn=E(cos(pi/i),f*sin(pi/i));
        for(int j=0;j1))
        {
            E w=E(1,0);
            for(int k=0;k*wn)
            {
                E x=a[j+k],y=w*a[i+j+k];
                a[j+k]=x+y;a[i+j+k]=x-y;
            }
        }   
    }
    if(f==-1)
        for(int i=0;iint x,int f)
{
    siz[x]=1;son[x]=0;
    for(int i=head[x];i;i=es[i].nex)
    {
        int v=es[i].v;
        if(vis[v] || v==f)
            continue;
        getroot(v,x);
        siz[x]+=siz[v];
        son[x]=max(son[x],siz[v]);
    }
    son[x]=max(son[x],sum-siz[x]);
    if(son[x]x;
}

void getdep(int x,int f,int dp)
{
    cnt[dp]++;siz[x]=1;mx=max(mx,dp);
    for(int i=head[x];i;i=es[i].nex)
    {
        int v=es[i].v;
        if(vis[v] || v==f)
            continue;
        getdep(v,x,dp+1);
        siz[x]+=siz[v];
    }
}

void calc(int mx,int f)
{
    L=0;m=1;
    for(;m<=mx*2;m<<=1,++L);
    for(int i=0;i<m;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));

    for(int i=0;i<=mx;++i)
        a[i]=E(cnt[i],0);
    for(int i=mx+1;i<m;++i)
        a[i]=E(0,0);
    fft(a,m,1);
    for(int i=0;i<m;++i)
        a[i]=a[i]*a[i];
    fft(a,m,-1);
    for(int i=1;i<=min(n-1,mx*2);++i)
        (dis[i]+=(LL)(a[i].r+0.5)%mod*f)%=mod;

}

void solve(int x)
{
    vis[x]=1;mx=0;getdep(x,0,0);calc(mx,1);
    for(int i=0;i<=mx;++i)
        cnt[i]=0;
    for(int i=head[x];i;i=es[i].nex)
    {
        int v=es[i].v;
        if(vis[v])
            continue;
        mx=0;getdep(v,x,1);calc(mx,-1);
        for(int j=0;j<=mx;++j)
            cnt[j]=0;
    }
    for(int i=head[x];i;i=es[i].nex)
    {
        int v=es[i].v;
        if(vis[v])
            continue;
        sum=siz[v];root=0;getroot(v,x);
        solve(root);
    }
}

void prepare()
{
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for(int i=2;i<=n;++i)
        fac[i]=1ll*fac[i-1]*i%mod,inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    for(int i=2;i<=n;++i)
        inv[i]=1ll*inv[i-1]*inv[i]%mod;
    dis[0]=n;
    for(int i=1;i%mod;

    L=0;m=1;
    for(;m<=n*2;m<<=1,++L);
    for(int i=0;i<m;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
}

void calc_ans()
{
    prepare();
    for(int i=0;i<=m;++i)
        a[i]=b[i]=E(0,0);
    for(int i=0;i1ll*dis[i]*fac[n-i-1]%mod/M,1ll*dis[i]*fac[n-i-1]%mod%M);    
        b[i]=E(inv[i]/M,inv[i]%M);
    }
    fft(a,m,1);fft(b,m,1);

    for(int i=0;i<m;++i)
    {
        int j=(m-i)&(m-1);
        da=(a[i]+conj(a[j]))*E(0.5,0);db=(a[i]-conj(a[j]))*E(0,-0.5);
        dc=(b[i]+conj(b[j]))*E(0.5,0);dd=(b[i]-conj(b[j]))*E(0,-0.5); 
        e[i]=da*dc;f[i]=da*dd;
        g[i]=db*dc;h[i]=db*dd;
    }
    for(int i=0;i<m;++i)
        a[i]=e[i]+f[i]*E(0,1),b[i]=g[i]+h[i]*E(0,1); 
    fft(a,m,-1);fft(b,m,-1);

    for(int i=0;i0.5)%mod*M%mod*M%mod;//add 0.5!
        ans[i]+=(LL)(a[i].i/m+0.5)%mod*M%mod;
        ans[i]+=(LL)(b[i].r+0.5)%mod*M%mod;
        ans[i]+=(LL)(b[i].i/m+0.5)%mod;
        ans[i]%=mod;
    }

    printf("%d ",1ll*n*n%mod);
    for(int i=1;iprintf("%d ",1ll*ans[n-i-1]*fac[n-i-1]%mod*inv[n-1]%mod);
}

int main()
{
    freopen("CC_CUTTREE.in","r",stdin);
    freopen("CC_CUTTREE.out","w",stdout);

    n=read();
    for(int i=1;iint u=read(),v=read();
        add(u,v);
    }
    son[0]=n;sum=n;root=0;getroot(1,0);
    solve(root);
    calc_ans();

    return 0;
}

【总结】
这题主要在于考虑如何求出期望值,这个树上期望的转换应该是很常用的。
然后的就是写代码的细节问题了。

你可能感兴趣的:(分而治之-树分治,数论-FFT/NTT)