求长度为 n n n的,每个位置的取值为 [ 1 , k ] [1,k] [1,k]之间的整数的,连续上升子段长度不超过 m m m的序列的个数。
n ≤ 1 e 9 n\le 1e9 n≤1e9
m , k ≤ 5 e 4 m,k\le 5e4 m,k≤5e4
神仙套路题……
首先可以想到一个DP:设 f n f_n fn表示长度为 n n n的序列的答案。方程: f n = ∑ f n − i ( k i ) f_n=\sum f_{n-i}\binom{k}{i} fn=∑fn−i(ik)。
这个DP显然是错误的,因为可能会有两个连续的序列合在一起。
尝试去用一些奇技淫巧来把它容斥掉:状态转移的时候乘上一个系数 g i g_i gi。
对于一个极长长度为 l e n len len的的连续序列,它的真实贡献为 [ 1 ≤ l e n ≤ m ] [1\le len \le m] [1≤len≤m]。设 F ( x ) = ∑ [ 1 ≤ i ≤ m ] x i F(x)=\sum[1\le i\le m]x^i F(x)=∑[1≤i≤m]xi
考虑这样一个序列会被计算多少次。于是有 ∑ i ≥ 1 G i ( x ) = F ( x ) \sum_{i\ge 1} G^i(x)=F(x) ∑i≥1Gi(x)=F(x),解一下方程就可以得到 G ( x ) G(x) G(x)。
设 g i = [ x i ] G ( x ) g_i=[x^i]G(x) gi=[xi]G(x),方程变成了 f n = ∑ f n − i g i ( k i ) f_n=\sum f_{n-i}g_i\binom{k}{i} fn=∑fn−igi(ik)。
然后它就对了…………
后面就是一个常系数线性递推的事情了。
using namespace std;
#include
#include
#include
#include
#define N 524288
#define ll long long
#define mo 998244353
#define check(x) printf("%d\n",x);
ll qpow(ll x,ll y=mo-2){
ll r=1;
for (;y;y>>=1,x=x*x%mo)
if (y&1)
r=r*x%mo;
return r;
}
int nN,re[N];
void setlen(int n){
int bit=0;
for (nN=1;nN<=n;nN<<=1,++bit);
for (int i=1;i<nN;++i)
re[i]=re[i>>1]>>1|(i&1)<<bit-1;
}
void clear(int A[],int n){memset(A,0,sizeof(int)*n);}
void copy(int A[],int a[],int n){clear(A,nN);for (int i=0;i<=n;++i) A[i]=a[i];}
void dft(int A[],int flag){
for (int i=0;i<nN;++i)
if (i<re[i])
swap(A[i],A[re[i]]);
static int wnk[N];
for (int i=1;i<nN;i<<=1){
ll wn=qpow(3,flag==1?(mo-1)/(2*i):mo-1-(mo-1)/(2*i));
wnk[0]=1;
for (int k=1;k<i;++k)
wnk[k]=wnk[k-1]*wn%mo;
for (int j=0;j<nN;j+=i<<1)
for (int k=0;k<i;++k){
ll x=A[j+k],y=(ll)A[j+k+i]*wnk[k];
A[j+k]=(x+y)%mo;
A[j+k+i]=(x-y)%mo;
}
}
if (flag==-1)
for (int i=0,invn=qpow(nN);i<nN;++i)
A[i]=(ll)A[i]*invn%mo;
for (int i=0;i<nN;++i)
A[i]=(A[i]+mo)%mo;
}
void multi(int c[],int a[],int b[],int n,int an=-1,int bn=-1){
if (an==-1) an=n-1;
if (bn==-1) bn=n-1;
static int A[N],B[N],C[N];
setlen(an+bn);
copy(A,a,an),dft(A,1);
if (a==b)
for (int i=0;i<nN;++i)
C[i]=(ll)A[i]*A[i]%mo;
else{
copy(B,b,bn),dft(B,1);
for (int i=0;i<nN;++i)
C[i]=(ll)A[i]*B[i]%mo;
}
dft(C,-1);
for (int i=0;i<=min(n-1,an+bn);++i)
c[i]=C[i];
}
void getinv(int c[],int a[],int n){
static int b[N],g[N];
int nn=1;for (;nn<n;nn<<=1);
clear(b,nn),clear(g,nn);
b[0]=qpow(a[0]);
for (int i=1;i<n;i<<=1){
multi(g,b,b,i*2,i-1,i-1);
multi(g,g,a,i*2,i*2-1,min(n,i*2-1));
for (int j=0;j<i*2;++j)
b[j]=(2ll*b[j]-g[j]+mo)%mo;
}
for (int i=0;i<n;++i)
c[i]=b[i];
}
void getrev(int A[],int a[],int n){for (int i=0;i<=n;++i) A[i]=a[n-i];}
void getdiv(int c[],int a[],int b[],int n,int m){
static int A[N],B[N],C[N];
clear(B,n-m+1),clear(A,n-m+1);
getrev(A,a,n),getrev(B,b,m);
getinv(B,B,n-m+1);
multi(C,A,B,n-m+1);
getrev(c,C,n-m);
}
void getmod(int c[],int a[],int b[],int n,int m){
static int D[N];
getdiv(D,a,b,n,m);
multi(D,D,b,n,n-m,m);
for (int i=0;i<m;++i)
c[i]=(a[i]-D[i]+mo)%mo;
}
int n,m,k;
int fac[N],ifac[N];
void initC(int n){
fac[0]=1;
for (int i=1;i<=n;++i)
fac[i]=(ll)fac[i-1]*i%mo;
ifac[n]=qpow(fac[n]);
for (int i=n-1;i>=0;--i)
ifac[i]=(ll)ifac[i+1]*(i+1)%mo;
}
ll C(int m,int n){return m<n?0:m==n?1:(ll)fac[m]*ifac[n]%mo*ifac[m-n]%mo;}
int f[N],g[N],a[N];
int q[N],mx;
void chang(int n){
if (n==0){
q[mx=0]=1;
return;
}
if (n&1){
chang(n-1);
for (int i=mx;i>=0;--i)
q[i+1]=q[i];
q[0]=0;
if (mx+1<k)
mx++;
else{
getmod(q,q,g,mx+1,k);
mx=k-1;
}
}
else{
chang(n>>1);
multi(q,q,q,2*mx+1,mx,mx);
if (2*mx<k)
mx*=2;
else{
getdiv(f,q,g,mx*2,k);
getmod(q,q,g,mx*2,k);
mx=k-1;
}
}
}
int main(){
//freopen("in.txt","r",stdin);
freopen("senior.in","r",stdin);
freopen("senior.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
initC(k);
for (int i=1;i<=m;++i)
f[i]=1;
g[0]=1;
for (int i=1;i<=m;++i)
g[i]=f[i];
getinv(g,g,k+1);
multi(g,f,g,k+1);
for (int i=1;i<=k;++i)
a[i]=(ll)g[i]*C(k,i)%mo;
g[k]=1;
for (int i=1;i<=k;++i)
g[k-i]=(mo-a[i])%mo;
chang(n-(-k+1));
ll ans=q[k-1];
printf("%lld\n",ans);
/*
dp[0]=1;
for (int i=1;i<=n;++i){
for (int j=1;j<=k && i-j>=0;++j)
(dp[i]+=dp[i-j]*C(k,j)%mo*g[j])%=mo;
}
*/
return 0;
}