Codeforces 981H K Paths 分治FFT+树形dp

题意

给一棵 n n 个节点的树,现在要从树上按顺序选出 k k 条路径(可以相同),满足任意一条边要么被覆盖不超过 1 1 次,要么被覆盖恰好 k k 次,且被覆盖 k k 次的边数不能为 0 0 。问方案。
n,k105 n , k ≤ 10 5

分析

先考虑暴力,我们可以枚举两个端点 u u v v ,然后保证每条选出的路径都包含这两个点之间的路径。
那么现在要从这两个点为根的子树中分别选出 k k 个端点,使得这些端点到根的路径没有公共边。
szv s z v 表示节点 v v 的子树大小, s1,s2,...,sm s 1 , s 2 , . . . , s m 表示 u u 的所有儿子,考虑多项式 (1+szs1x)(1+szs2x)...=(aixi) ( 1 + s z s 1 x ) ( 1 + s z s 2 x ) . . . = ∑ ( a i x i ) ,那么选出 k k 个端点的方案就是 fv=axCxkx! f v = ∑ a x C k x x !
而这个多项式可以用分治FFT来算,复杂度是 O(nlog2n) O ( n l o g 2 n )
如果我们把原树定为有根树,然后算出每个点 v v 为根的 fv f v ,就可以通过树形dp来计算所有两端点不为祖先关系的答案。
对于那些两端点为祖先关系的答案,我们可以枚举深度较小的点 v v ,那么如果我们选了 v v 的一个儿子 u u 子树中的点作为另一个端点,节点 v v 对应的多项式就要乘上 1+(nszv)x1+szux 1 + ( n − s z v ) x 1 + s z u x
显然乘或除以一个一次多项式可以在 O(degree) O ( d e g r e e ) 时间内完成,而注意到不同的 szu s z u 只有 O(n) O ( n ) 种,所以复杂度就是 O(nn) O ( n n )
总的复杂度就是 O(nlog2n+nn) O ( n l o g 2 n + n n )

代码

#include
#include
#include
#include
#include

typedef long long LL;

const int N=100005;
const int MOD=998244353;

int n,k,cnt,last[N],jc[N],ny[N],a[20][N*2],b[N],tot,rev[N*2],size[N],f[N],s[N],L,ans;
struct edge{int to,next;}e[N*2];
struct data{int x,y;}t[N];

int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

void addedge(int u,int v)
{
    e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
    e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}

bool cmp(data a,data b)
{
    return a.xx;
}

int calc(int *a,int n) 
{
    int ans=0;
    for (int i=0;i<=std::min(n,k-1);i++)
        (ans+=(LL)a[i]*jc[k]%MOD*ny[k-i]%MOD)%=MOD;
    return ans;
}

int ksm(int x,int y)
{
    int ans=1;
    while (y)
    {
        if (y&1) ans=(LL)ans*x%MOD;
        x=(LL)x*x%MOD;y>>=1;
    }
    return ans;
}

void NTT(int *a,int f)
{
    for (int i=0;iif (ifor (int i=1;i1)
    {
        int wn=ksm(3,f==1?(MOD-1)/i/2:MOD-1-(MOD-1)/i/2);
        for (int j=0;j1))
        {
            int w=1;
            for (int k=0;kint u=a[j+k],v=(LL)w*a[j+k+i]%MOD;
                a[j+k]=(u+v)%MOD;a[j+k+i]=(u+MOD-v)%MOD;
                w=(LL)w*wn%MOD;
            }
        }
    }
    int ny=ksm(L,MOD-2);
    if (f==-1) for (int i=0;i*ny%MOD;
}

void solve(int l,int r,int d)
{
    if (l==r) {a[d][0]=1;a[d][1]=t[l].x;return;}
    int mid=(l+r)/2;
    solve(l,mid,d+1);
    for (int i=0;i<=mid-l+1;i++) a[d][i]=a[d+1][i];
    solve(mid+1,r,d+1);
    int lg=0;
    for (L=1;L<=r-l+1;L<<=1,lg++);
    for (int i=0;i>1]>>1)|((i&1)<<(lg-1));
    for (int i=mid-l+2;i0;
    for (int i=r-mid+1;i1][i]=0;
    NTT(a[d],1);NTT(a[d+1],1);
    for (int i=0;i*a[d+1][i]%MOD;
    NTT(a[d],-1);
}

void dfs1(int x,int fa)
{
    size[x]=1;
    for (int i=last[x];i;i=e[i].next)
    {
        if (e[i].to==fa) continue;
        dfs1(e[i].to,x);
        size[x]+=size[e[i].to];
        (s[x]+=s[e[i].to])%=MOD;
    }
    tot=0;
    for (int i=last[x];i;i=e[i].next)
        if (e[i].to!=fa) t[++tot].x=size[e[i].to],t[tot].y=s[e[i].to];
    if (!tot) {f[x]=s[x]=1;return;}
    std::sort(t+1,t+tot+1,cmp);
    solve(1,tot,0);
    f[x]=calc(a[0],tot);(s[x]+=f[x])%=MOD;
    a[0][tot+1]=0;
    for (int i=tot+1;i>=1;i--) (a[0][i]+=(LL)a[0][i-1]*(n-size[x])%MOD)%=MOD;
    int w;
    for (int i=1;i<=tot;i++)
    {
        if (t[i].x==t[i-1].x) {(ans+=(LL)w*t[i].y%MOD)%=MOD;continue;}
        for (int j=0;j<=tot+1;j++) b[j]=a[0][j];
        for (int j=1;j<=tot+1;j++) (b[j]+=MOD-(LL)b[j-1]*t[i].x%MOD)%=MOD;
        w=calc(b,tot);
        (ans+=(LL)w*t[i].y%MOD)%=MOD;
    }
}

void dfs2(int x,int fa)
{
    int w=0;
    for (int i=last[x];i;i=e[i].next)
    {
        if (e[i].to==fa) continue;
        dfs2(e[i].to,x);
        (ans+=(LL)w*s[e[i].to]%MOD)%=MOD;
        (w+=s[e[i].to])%=MOD;
    }
}

int main()
{
    n=read();k=read();
    jc[0]=jc[1]=ny[0]=ny[1]=1;
    for (int i=2;i<=k;i++) jc[i]=(LL)jc[i-1]*i%MOD,ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
    for (int i=2;i<=k;i++) ny[i]=(LL)ny[i-1]*ny[i]%MOD;
    for (int i=1;iint x=read(),y=read();
        addedge(x,y);
    }
    dfs1(1,0);
    dfs2(1,0);
    printf("%d",ans);
    return 0;
}

你可能感兴趣的:(树形dp,快速傅里叶变换,分治)