[BZOJ4555][Tjoi2016&Heoi2016]求和(NTT)

题目描述

传送门

题目大意:
f(n)=ni=0ij=0S(i,j)×2j×(j!)
其中 S(i,j) 为第二类斯特林数,递推公式为:
S(i,j)=j×S(i1,j)+S(i1,j1),1ji1
边界条件为: S(i,i)=1(0i),S(i,0)=0(1i)

题解

感觉这题给出递推公式就是满满的恶意【有本事你自己推出来通项系列】
第二类斯特林数的通项公式为

S(i,j)=1j!k=0j(1)kCkj(jk)i

将这个公式带入式子可以得到
i=0nj=0i1j!k=0j(1)kCkj(jk)i×2j×(j!)

=i=0nj=0i2jk=0j(1)kj!k!(jk)!(jk)i

可以发现 (jk)i 这个东西是比较难搞的,所以可以考虑把循环顺序换一下
j(j!)2jk(1)kk!1(jk)!i(jk)i

若令 F(j)=(j!)2jk(1)kk!1(jk)!i(jk)i ,实际上就是要求出 F 的每一项,进而求出 F(1..n) 的和
f(n)=1n!ini,g(n)=(1)nn! ,其实 F 就可以化成一个卷积的形式
F(j)=(j!)2jkf(jk)g(k)

利用预处理阶乘、逆元还有等比数列的通项公式预处理 f,g ,就可以直接用NTT求解了
还需要考虑的一个问题是上下界问题,例如原式中j,k的越界。实际上这是对答案没有影响的,因为当j>i时第二类斯特林数都为0,通项公式就保证了这一点,而k>j时组合数也为0,所以所有循环的上下界均为[0..n]即可

代码

#include
#include
#include
#include
#include
using namespace std;
#define LL long long
#define Mod 998244353
#define N 300005

int lim,m,n,L,R[N];
LL mul[N],mi[N],inv[N],invmul[N],a[N],b[N],ans;

LL fast_pow(LL a,int p)
{
    LL ans=1LL;
    for (;p;p>>=1,a=a*a%Mod)
        if (p&1)
            ans=ans*a%Mod;
    return ans;
}
void init()
{
    mul[0]=1LL;mi[0]=1LL;
    for (int i=1;i<=lim;++i) mul[i]=mul[i-1]*(LL)i%Mod,mi[i]=mi[i-1]*2LL%Mod;
    inv[1]=1LL;
    for (int i=2;i<=n;++i) inv[i]=inv[Mod%i]*(Mod-Mod/i)%Mod;
    invmul[0]=1LL;
    for (int i=1;i<=lim;++i) invmul[i]=invmul[i-1]*inv[i]%Mod;
    a[0]=b[0]=1LL;
    for (int i=1,opt=-1;i<=lim;++i,opt=-opt)
    {
        b[i]=opt*invmul[i]%Mod;
        if (i==1) a[i]=(LL)lim+1;
        else a[i]=invmul[i]*(fast_pow((LL)i,lim+1)-1)%Mod*inv[i-1]%Mod;
    }
}
void FFT(LL *a,int opt)
{
    for (int i=0;iif (ifor (int k=1;k1)
    {
        LL wn=fast_pow(3LL,(Mod-1)/(k<<1));
        for (int i=0;i1))
        {
            LL w=1LL;
            for (int j=0;j*wn%Mod)
            {
                LL x=a[i+j],y=w*a[i+j+k]%Mod;
                a[i+j]=(x+y)%Mod,a[i+j+k]=(x-y+Mod)%Mod;
            }
        }
    }
    if (opt==-1) reverse(a+1,a+n);
}

int main()
{
    scanf("%d",&lim);
    m=lim<<1;
    for (n=1;n<=m;n<<=1) ++L;
    for (int i=0;i>1]>>1)|((i&1)<<(L-1));
    init();
    FFT(a,1);FFT(b,1);
    for (int i=0;i<=n;++i) a[i]=a[i]*b[i]%Mod;
    FFT(a,-1);
    for (int i=0;i<=lim;++i)
    {
        LL now=mul[i]*mi[i]%Mod*a[i]%Mod*inv[n]%Mod;
        ans=(ans+now)%Mod;
    }
    ans=(ans+Mod)%Mod;
    printf("%lld\n",ans);
}

你可能感兴趣的:(题解,FFT/NTT)