简要题意:
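对总点数为 $n$，且每个连通块（即每棵树）大小都不超过 $a$ 的带标号无根树森林计数。
由于森林是若干棵带标号树的带标号拼接，先写出单棵树的 EGF：由 Cayley 公式，$i$ 个点的带标号无根树有 $i^{i-2}$ 棵，故 $T(x)=\sum_{i=1}^{a}\frac{i^{i-2}}{i!}x^i$；对其取 Exp 得到森林的 EGF，答案为 $n!\,[x^n]\exp T(x)$。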
代码:
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
// Shorthands used throughout the file.
#define ll long long
// `re` used to expand to `register`. That storage-class specifier was removed
// in C++17: it is ill-formed there, and Clang rejects it by default. It was
// only ever an optimization hint, so `re` now expands to nothing.
#define re
#define cs const
using std::cerr;
using std::cout;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
cs int mod=998244353;
// Modular helpers. All operands are expected to already lie in [0, mod).
// (a + b) mod p.
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
// (a - b) mod p.
inline int dec(int a,int b){return a-b<0?a-b+mod:a-b;}
// (a * b) mod p, via a 64-bit intermediate.
inline int mul(int a,int b){ll r=(ll)a*b;return r>=mod?r%mod:r;}
// a = (a + b) mod p, branch-free (sign bit selects the +mod correction).
inline void Inc(int &a,int b){a+=b-mod;a+=a>>31&mod;}
// a = (a - b) mod p, branch-free.
inline void Dec(int &a,int b){a-=b;a+=a>>31&mod;}
// a = (a * b) mod p.
inline void Mul(int &a,int b){a=mul(a,b);}
// a^b mod p by binary exponentiation. b must be non-negative
// (a negative b never reaches 0 under >>= and loops forever).
inline int po(int a,int b){int r=1;for(;b;b>>=1,Mul(a,a))if(b&1)Mul(r,a);return r;}
// Extended Euclid: sets x, y so that a*x + b*y == gcd(a, b).
inline void ex_gcd(int a,int b,int &x,int &y){
if(!b){x=1,y=0;return ;}ex_gcd(b,a%b,y,x);y-=a/b*x;
}
// Modular inverse of a (requires a not divisible by mod), normalized to [0, mod).
inline int Inv(int a){int x,y;ex_gcd(mod,a,y,x);return x+(x>>31&mod);}
cs int bit=19,SIZE=1<<bit|7;
int r[SIZE],*w[bit+1],Log[SIZE];
void init_omega(){
for(int re i=1;i<=bit;++i)
w[i]=new int[1<<(i-1)];
int wn=po(3,(mod-1)>>bit);w[bit][0]=1;
for(int re i=1;i<(1<<(bit-1));++i)
w[bit][i]=mul(w[bit][i-1],wn);
for(int re i=bit-1;i;--i)
for(int re j=0;j<(1<<(i-1));++j)
w[i][j]=w[i+1][j<<1];
for(int re i=2;i<SIZE;++i)
Log[i]=Log[i-1]+((1<<Log[i-1])<i);
}
int fac[SIZE],ifc[SIZE],inv[SIZE];
void init_fac(){
fac[0]=fac[1]=1;
ifc[0]=ifc[1]=1;
inv[0]=inv[1]=1;
for(int re i=2;i<SIZE;++i){
fac[i]=mul(fac[i-1],i);
inv[i]=mul(inv[mod%i],mod-mod/i);
ifc[i]=mul(ifc[i-1],inv[i]);
}
}
int len,inv_len;
void DFT(int *A){
for(int re i=1;i<len;++i)
if(i<r[i])std::swap(A[i],A[r[i]]);
for(int re i=1,d=1;i<len;i<<=1,++d)
for(int re j=0;j<len;j+=i<<1)
if(i<8){
for(int re k=0;k<i;++k){
int &t1=A[j+k],&t2=A[j+k+i];
int t=mul(t2,w[d][k]);
t2=dec(t1,t);Inc(t1,t);
}
}else {
#define work(p) \
{ \
int &t1=A[j+k+p],&t2=A[j+k+i+p]; \
int t=mul(t2,w[d][k+p]); \
t2=dec(t1,t),Inc(t1,t); \
}
for(int re k=0;k<i;k+=8){
work(0);work(1);work(2);work(3);
work(4);work(5);work(6);work(7);
}
}
}void IDFT(int *A){
DFT(A);std::reverse(A+1,A+len);
for(int re i=0;i<len;++i)Mul(A[i],inv_len);
}void init_len(int l){
len=l;inv_len=inv[l];
for(int re i=1;i<l;++i)
r[i]=r[i>>1]>>1|((i&1)?l>>1:0);
}
typedef std::vector<int> Poly;
inline void DFT(Poly &A){DFT(&A[0]);}
inline void IDFT(Poly &A){IDFT(&A[0]);}
// Scratch buffers for operator*. They are reused so each multiply avoids an allocation.
int A[SIZE],B[SIZE];
/// Polynomial product a*b via NTT. Returns a.size()+b.size()-1 coefficients,
/// or an empty Poly if either factor is empty. Overwrites the global
/// transform length (init_len).
Poly operator*(cs Poly &a,cs Poly &b){
    if(a.empty()||b.empty())
        return Poly(0);
    int deg=a.size()+b.size()-1;
    init_len(1<<Log[deg]);  // smallest power of two >= deg
    // Load both operands and zero-pad up to the transform length.
    std::fill(std::copy(a.begin(),a.end(),A),A+len,0);
    std::fill(std::copy(b.begin(),b.end(),B),B+len,0);
    DFT(A),DFT(B);
    for(int i=0;i<len;++i)
        Mul(A[i],B[i]);
    IDFT(A);
    return Poly(A,A+deg);
}
/// Formal derivative. Returns a.size()-1 coefficients (empty for an empty or constant a).
Poly Deriv(Poly a){
    if(a.empty())return Poly(0);
    for(int i=1;i<(int)a.size();++i)
        a[i-1]=mul(a[i],i);
    a.pop_back();
    return a;
}
/// Formal integral with a zero constant term. Returns a.size()+1 coefficients
/// (empty for an empty a). Reads the inverse table inv[].
Poly Integ(Poly a){
    if(a.empty())return Poly(0);
    a.push_back(0);
    for(int i=(int)a.size()-1;i>0;--i)
        a[i]=mul(a[i-1],inv[i]);
    a[0]=0;
    return a;
}
/// Multiplicative inverse of the series a modulo x^lim. Returns exactly lim coefficients.
/// Requires a to be non-empty with a[0] != 0.
/// Uses Newton iteration b <- b*(2 - a*b), which doubles the number of correct terms per round.
Poly Inv(cs Poly &a,int lim){
    int n=a.size();Poly c,b(1,Inv(a[0]));
    // A round at transform length l starts with b correct mod x^(l/4) and ends
    // with b correct mod x^(l/2). The product b*b*c has degree < l, so the
    // cyclic convolution of length l does not wrap around.
    for(int l=4;(l>>2)<lim;l<<=1){
        init_len(l);
        c.assign(l,0);
        for(int i=0;i<(l>>1)&&i<n;++i)
            c[i]=a[i];
        DFT(c);
        b.resize(l),DFT(b);
        for(int i=0;i<l;++i)
            Mul(b[i],dec(2,mul(b[i],c[i])));
        IDFT(b);b.resize(l>>1);
    }
    return Poly(b.begin(),b.begin()+lim);
}
/// Logarithm of the series a modulo x^lim. Returns exactly lim coefficients.
/// Requires a[0] == 1 (so that ln a has a zero constant term).
/// Computed as ln a = Integ(a' / a).
/// Previously this kept lim+1 coefficients of a'/a before integrating, so it
/// returned lim+2 coefficients whose top two were not correct mod x^lim.
Poly Ln(Poly a,int lim){
    Poly d=Deriv(a)*Inv(a,lim);
    d.resize(lim);   // only d mod x^(lim-1) matters for the first lim terms
    d=Integ(d);      // lim+1 coefficients
    d.resize(lim);
    return d;
}
/// Exponential of the series a modulo x^lim. Returns exactly lim coefficients.
/// Requires a[0] == 0.
/// Uses Newton iteration b <- b*(1 - ln b + a), doubling the correct length each round.
Poly Exp(cs Poly &a,int lim){
    int n=a.size();Poly c,b(1,1);
    // Entering a round with length i, b is correct mod x^(i/2). Leaving it, b is correct mod x^i.
    for(int i=2;(i>>1)<lim;i<<=1){
        c=Ln(b,i);
        // c = 1 + a - ln(b)  (mod x^i)
        Dec(c[0],1);
        for(int j=0;j<i;++j)
            c[j]=dec(j<n?a[j]:0,c[j]);
        b=b*c;b.resize(i);
    }
    return Poly(b.begin(),b.begin()+lim);
}
int n,a;
Poly F;
void Main(){
scanf("%d%d",&n,&a);
init_omega();init_fac();
F.resize(n+1);F[1]=1;
for(int re i=2;i<=a;++i)
F[i]=mul(po(i,i-2),ifc[i]);
F=Exp(F,n+1);
cout<<mul(F[n],fac[n])<<"\n";
}
/// Local-testing I/O redirection.
/// With zxyoi defined, only stdin is read from forest.in.
/// Otherwise, unless ONLINE_JUDGE is defined, stdin and stdout are redirected
/// to forest.in and forest.out. On the judge it does nothing.
inline void file(){
#if defined(zxyoi)
    freopen("forest.in","r",stdin);
#elif !defined(ONLINE_JUDGE)
    freopen("forest.in","r",stdin);
    freopen("forest.out","w",stdout);
#endif
}
// Program entry: set up I/O, then solve.
int main(){
    file();
    Main();
    return 0;
}