Codechef CUTTREE 点分治+任意模数FFT

题意

定义森林的强度为连通块大小的平方和。
第0天有一棵n个节点的树,每一天大厨会随机删掉森林中的一条边,共进行n-1天,对于 i=0,1,...,n1 i = 0 , 1 , . . . , n − 1 ,求出第i天结束时这棵树的强度的期望值。答案模 109+7 10 9 + 7
n105 n ≤ 10 5

分析

对于一个有序点对 (x,y) ( x , y ) ,若在第i天结束时这两点连通,则会对森林的强度作出1的贡献。那么我们要求的实际上就是期望有多少有序点对 (x,y) ( x , y ) 满足 x x y y 连通。
dis(x,y) d i s ( x , y ) 表示 x x y y 的距离,不难发现 (x,y) ( x , y ) 对第i天的贡献为

(n1dis(x,y))i(n1)i ( n − 1 − d i s ( x , y ) ) i _ ( n − 1 ) i _

如果我们预处理出 cnti c n t i 表示有多少点对的距离为i,那么
ansi=d=0n1(n1d)!(n1i)!cntd(n1di)!(n1)! a n s i = ∑ d = 0 n − 1 ( n − 1 − d ) ! ∗ ( n − 1 − i ) ! ∗ c n t d ( n − 1 − d − i ) ! ( n − 1 ) !

不难发现是一个卷积形式,用FFT优化即可。
考虑用点分治来求 cnti c n t i ,发现这也是一个卷积形式,用FFT优化即可。
由于模数不是NTT模数,所以要上任意模数FFT。

代码

#include
#include
#include
#include
#include
#include
#include

typedef long long LL;
typedef long double db;

const int N=100005;
const int MOD=1000000007;
const int B=sqrt(MOD)+1;
const db pi=acos(-1.0);

int n,cnt,last[N],size[N],f[N],root,sum,rev[N*4],L,jc[N],ny[N],d[N],t[N],mx,a[N];
bool vis[N];
struct edge{int to,next;}e[N*2];
struct com
{
    db x,y;

    com operator + (const com &d) const {return (com){x+d.x,y+d.y};}
    com operator - (const com &d) const {return (com){x-d.x,y-d.y};}
    com operator * (const com &d) const {return (com){x*d.x-y*d.y,x*d.y+y*d.x};}
    com operator / (const db &d) const {return (com){x/d,y/d};}
}a1[N*4],b1[N*4],a2[N*4],b2[N*4],c[N*4];
std::vector vec[N*4];

void addedge(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
    e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}

void get_root(int x,int fa)
{
    size[x]=1;f[x]=0;
    for (int i=last[x];i;i=e[i].next)
    {
        if (e[i].to==fa||vis[e[i].to]) continue;
        get_root(e[i].to,x);
        size[x]+=size[e[i].to];
        f[x]=std::max(f[x],size[e[i].to]);
    }
    f[x]=std::max(f[x],sum-size[x]);
    if (!root||f[x]void pre()
{
    for (int i=1;i1)
    {
        vec[i].clear();
        for (int k=0;kcos(pi*k/i),sin(pi*k/i)});
    }
}

void FFT(com *a,int f)
{
    for (int i=0;iif (istd::swap(a[i],a[rev[i]]);
    for (int i=1;i1)
    {
        for (int j=0;j1))
        {
            com w=(com){1,0};
            for (int k=0;kif (f==-1) for (int i=0;ivoid calc(int mx,int f)
{
    for (int i=0;i<=mx;i++) a1[i]=(com){t[i],0};
    int lg=0;
    for (L=1;L<=mx*2;L<<=1,lg++);
    for (int i=0;i>1]>>1)|((i&1)<<(lg-1));
    pre();
    for (int i=mx+1;i0,0};
    FFT(a1,1);
    for (int i=0;i1);
    for (int i=1;i<=std::min(n-1,mx*2);i++) (d[i]+=(LL)(a1[i].x+0.5)%MOD*f)%=MOD;
}

void get(int x,int fa,int dep)
{
    t[dep]++;size[x]=1;mx=std::max(mx,dep);
    for (int i=last[x];i;i=e[i].next)
        if (e[i].to!=fa&&!vis[e[i].to]) get(e[i].to,x,dep+1),size[x]+=size[e[i].to];
}

void solve(int x)
{
    vis[x]=1;
    mx=0;get(x,0,0);
    calc(mx,1);
    for (int i=0;i<=mx;i++) t[i]=0;
    for (int i=last[x];i;i=e[i].next)
        if (!vis[e[i].to])
        {
            mx=0;get(e[i].to,x,1);
            calc(mx,-1);
            for (int j=0;j<=mx;j++) t[j]=0;
        }
    for (int i=last[x];i;i=e[i].next)
    {
        if (vis[e[i].to]) continue;
        sum=size[e[i].to];root=0;get_root(e[i].to,x);
        solve(root);
    }
}

int main()
{
    scanf("%d",&n);
    for (int i=1;iint x,y;scanf("%d%d",&x,&y);
        addedge(x,y);
    }
    sum=n;root=0;get_root(1,0);
    solve(root);
    for (int i=1;i0?MOD:0;
    d[0]=n;
    jc[0]=jc[1]=ny[0]=ny[1]=1;
    for (int i=2;i<=n;i++) jc[i]=(LL)jc[i-1]*i%MOD,ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
    for (int i=2;i<=n;i++) ny[i]=(LL)ny[i-1]*ny[i]%MOD;
    for (int i=0;i1-i]%MOD/B,0},b1[i]=(com){(LL)d[i]*jc[n-1-i]%MOD%B,0};
    for (int i=0;i0},b2[i]=(com){ny[i]%B,0};
    int lg=0;
    for (L=1;L<=n*2;L<<=1,lg++);
    for (int i=0;i>1]>>1)|((i&1)<<(lg-1));
    pre();
    for (int i=n;i0,0};
    FFT(a1,1);FFT(b1,1);FFT(a2,1);FFT(b2,1);
    for (int i=0;i1);
    for (int i=0;i0.5)%MOD*B*B%MOD)%=MOD;
    for (int i=0;i1);
    for (int i=0;i0.5)%MOD*B%MOD)%=MOD;
    for (int i=0;i1);
    for (int i=0;i0.5)%MOD)%=MOD;
    printf("%d ",(LL)n*n%MOD);
    for (int i=1;iprintf("%d ",(LL)a[n-1-i]*jc[n-1-i]%MOD*ny[n-1]%MOD);
    return 0;
}

你可能感兴趣的:(点分治,快速傅里叶变换)