4555: [Tjoi2016&Heoi2016]求和
Time Limit: 40 Sec
Memory Limit: 128 MB
Submit: 116
Solved: 97
[ Submit][ Status][ Discuss]
Description
在2016年,佳媛姐姐刚刚学习了第二类斯特林数,非常开心。
现在他想计算这样一个函数的值:
S(i, j)表示第二类斯特林数,递推公式为:
S(i, j) = j ∗ S(i − 1, j) + S(i − 1, j − 1), 1 <= j <= i − 1。
边界条件为:S(i, i) = 1(0 <= i), S(i, 0) = 0(1 <= i)
你能帮帮他吗?
Input
Output
输出f(n)。由于结果会很大,输出f(n)对998244353(7 × 17 × 223 + 1)取模的结果即可。1 ≤ n ≤ 100000
Sample Input
3
Sample Output
87
分治+NTT,思路好
令F[i]=∑(0≤j≤i)S(i,j)*(2^j)*(j!),则f[n]=∑(0≤i≤n)F[i]。
首先,第二类斯特林数S(i,j)的意义是将i个数分到j个无序集合的方案数。
那么F[i]的含义是将i个数分到任意个有序集合的方案数,并且枚举每一个集合选或不选。
假设第一个集合有i-j个数,则其他的集合共有j个数,所以得到递推式F[i]=∑(1≤j≤i)F[j]*C(i,j)*2。
将组合数展开,两面同除以i!,得F[i]/i!=∑(F[j]/j!)*2/(i-j)!
可以发现等式右边是一个卷积的形式,而模数比较特殊,所以可以用NTT。
还有一个问题就是等式左右都有F数组,用分治可以解决,每次处理[l,mid]对[mid+1,r]的影响。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<cstring>
#include<algorithm>
#define F(i,j,n) for(int i=j;i<=n;i++)
#define D(i,j,n) for(int i=j;i>=n;i--)
#define ll long long
#define N 400005
#define mod 998244353
using namespace std;
int n;
ll ans,a[N],b[N],f[N],fac[N],inv[N],rev[N];
inline ll getpow(ll x,ll y)
{
ll ret=1;
for(;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;
return ret;
}
void ntt(ll *a,int n,int flg)
{
F(i,0,n-1) if (i<rev[i]) swap(a[i],a[rev[i]]);
for(int m=2;m<=n;m<<=1)
{
int mid=m>>1;
ll wn=getpow(3,((mod-1)/m*flg+mod-1)%(mod-1));
for(int i=0;i<n;i+=m)
{
ll w=1;
F(j,0,mid-1)
{
ll u=a[i+j],v=a[i+j+mid]*w%mod;
a[i+j]=(u+v)%mod;a[i+j+mid]=(u-v+mod)%mod;
w=w*wn%mod;
}
}
}
if (flg==-1)
{
ll inv=getpow(n,mod-2);
F(i,0,n-1) a[i]=a[i]*inv%mod;
}
}
void solve(int l,int r)
{
if (l==r) return;
int mid=(l+r)>>1;
solve(l,mid);
int n=mid-l+r-l,len=1;
while ((1<<len)<n) len++;
n=(1<<len);
F(i,0,n-1) a[i]=b[i]=0;
F(i,l,mid) a[i-l]=f[i];
F(i,1,r-l) b[i-1]=inv[i]*2%mod;
F(i,1,n-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
ntt(a,n,1);ntt(b,n,1);
F(i,0,n-1) a[i]=a[i]*b[i]%mod;
ntt(a,n,-1);
F(i,mid+1,r) f[i]=(f[i]+a[i-l-1])%mod;
solve(mid+1,r);
}
int main()
{
scanf("%d",&n);
fac[0]=inv[0]=1;
F(i,1,n) fac[i]=fac[i-1]*i%mod,inv[i]=getpow(fac[i],mod-2);
f[0]=1;
solve(0,n);
F(i,0,n) ans=(ans+f[i]*fac[i]%mod)%mod;
cout<<ans<<endl;
return 0;
}