UOJ86 mx的组合数

大概看到题的时候就会做了。好厉害的题。

组合数模质数p等于某值的方案数,很容易想到利用卢卡斯定理。然后要使p进制下每一位对出的结果的乘积在模p意义下为某值,数位dp一波就好。

暴力转移是每位p^2的,但是转移的形式是c[i*j]+=a[i]*b[j],可以考虑找原根R,这就变成了c[logRi+logRj]+=a[logRi]*b[logRj],这里NTT就好了,模数还刚好是998244353。

然后你让我写。就很麻烦了。

先是一波高精度处理。找原根并预处理阶。在阶的基础上转换卷积形式以及NTT板子。哦还有预处理一波阶乘来求组合数。最后就是数位dp。

写还有调搞得我心力憔悴。

代码:

#include 
#include 
#include 
#define ll long long
#define AwD 998244353
int p;
struct bigint{
	int v[105],len;
}n,l,r;
void read(bigint&a){
	char s[35];scanf("%s",s);
	a.len=strlen(s);
	for(int i=0;i<a.len;i++) a.v[a.len-i]=s[i]-'0';
}
int operator%(bigint a,int b){
	int res=0;
	for(int i=a.len;i;i--) res=(res*10+a.v[i])%b;
	return res;
}
bigint operator/(bigint a,int b){
	for(int i=a.len;i;i--){
		a.v[i-1]+=a.v[i]%b*10;
		a.v[i]/=b;
	}
	while(a.len>1&&!a.v[a.len]) a.len--;
	return a;
}
bigint incr(bigint a){
	a.v[1]++;
	for(int i=1;i<a.len;i++) if(a.v[i]>=10){
		a.v[i]-=10;a.v[i+1]++;
	}
	if(a.v[a.len]>=10){
		a.v[a.len]-=10;a.v[++a.len]=1;
	}
	return a;
}
bool zero(bigint a){
	return a.len==1&&!a.v[1];
}
void trs(bigint&a){
	int b[105],n=0;
	while(!zero(a)){
		b[++n]=a%p;
		a=a/p;
	}
	for(int i=1;i<=n;i++) a.v[i]=b[i];
	a.len=n;
}
void exp0(bigint&a,int L){
	for(int i=a.len+1;i<=L;i++) a.v[i]=0;
}
int kth[30005],rk[30005],R;
void findR(){
	R=1;
	while(1){
		for(int i=1;i<p;i++) rk[i]=0;
		bool flag=kth[0]=1;
		for(int i=1;i<p;i++){
			kth[i]=kth[i-1]*R%p;
			if(rk[kth[i]]){
				flag=0;break;
			}
			rk[kth[i]]=i;
		}
		if(!flag){
			R++;continue;
		}
		rk[1]=0;
		return;
	}
}
ll pw(ll x,ll y){
	if(y<0) y+=AwD-1;
	if(!y) return 1;
	ll res=pw(x,y>>1);
	(res*=res)%=AwD;
	if(y&1) (res*=x)%=AwD;
	return res;
}
ll ntt(ll*a,int n,int d){
	int i,j,k;
	ll w,t,u,v;
	for(i=(n>>1),j=1;j<n;j++){
		if(i<j) t=a[i],a[i]=a[j],a[j]=t;
		for(k=(n>>1);i&k;i^=k,k>>=1);i^=k;
	}
	for(k=2;k<=n;k<<=1){
		w=pw(3,(AwD-1)/k*d);
		for(i=0;i<n;i+=k){
			t=1;
			for(j=i;j<i+(k>>1);j++){
				u=a[j];v=t*a[j+(k>>1)]%AwD;
				a[j]=(u+v)%AwD;a[j+(k>>1)]=(u-v+AwD)%AwD;t=t*w%AwD;
			}
		}
	}
}
ll t1[65555],t2[65555];
void multi(int*a,int*b,int*res){
	int res0=0;
	for(int i=0;i<p;i++){
		res0=(res0+1ll*a[i]*b[0])%AwD;
		if(i) res0=(res0+1ll*a[0]*b[i])%AwD;
	}
	//for(int i=0;i
	//for(int i=0;i
	for(int i=1;i<p;i++) t1[rk[i]]=a[i],t2[rk[i]]=b[i];
	int l=1,invl;while(l<p-1) l<<=1;invl=pw(l<<=1,-1);
	for(int i=p-1;i<l;i++) t1[i]=t2[i]=0;
	//for(int i=0;i
	//for(int i=0;i
	ntt(t1,l,1);ntt(t2,l,1);
	for(int i=0;i<l;i++) (t1[i]*=t2[i])%=AwD;
	ntt(t1,l,-1);
	for(int i=0;i<l;i++) t1[i]=t1[i]*invl%AwD;
	//for(int i=0;i
	for(int i=1;i<p;i++) res[i]=0;
	for(int i=0;i<l;i++) (res[kth[i%(p-1)]]+=t1[i])%=AwD;
	res[0]=res0;
	//for(int i=0;i
}
int fac[30005],inv[30005];
int C(int n,int m){
	return n<m?0:fac[n]*inv[m]%p*inv[n-m]%p;
}
int L,dp[105][30005],tmp[30005],cur;
void init(){
	fac[0]=1;for(int i=1;i<p;i++) fac[i]=fac[i-1]*i%p;
	inv[p-1]=kth[p-1-rk[fac[p-1]]];for(int i=p-1;i;i--) inv[i-1]=inv[i]*i%p;
	for(int i=0;i<p;i++) dp[0][i]=i==1;
	for(int i=1;i<L;i++){
		for(int j=0;j<p;j++) tmp[j]=0;
		for(int j=0;j<p;j++) tmp[C(j,n.v[i])]++;
		multi(dp[i-1],tmp,dp[i]);
	}
}
void solve(bigint a,int*ans){
	//for(int i=0;i
	//printf("solving...\n");
	for(int i=0;i<p;i++) ans[i]=0;
	cur=1;
	for(int i=L;i;i--){
		for(int j=0;j<p;j++) tmp[j]=0;
		for(int j=0;j<a.v[i];j++) tmp[C(j,n.v[i])]++;
		//for(int j=0;j
		multi(tmp,dp[i-1],tmp);
		//for(int j=0;j
		for(int j=0;j<p;j++) (ans[j*cur%p]+=tmp[j])%=AwD;
		(cur*=C(a.v[i],n.v[i]))%=p;
	}
	//printf("solved\n");
}
int ans1[30005],ans2[30005]; 
int main(){
	scanf("%d",&p);read(n);read(l);read(r);r=incr(r);
	//printf("reading ok\n");
	trs(n);trs(l);trs(r);
	//printf("transforming ok\n");
	L=std::max(n.len,std::max(l.len,r.len));
	//printf("L=%d\n",L);
	exp0(n,L);exp0(l,L);exp0(r,L);
	//for(int i=L;i;i--) printf("%d ",n.v[i]);printf("\n");
	//for(int i=L;i;i--) printf("%d ",l.v[i]);printf("\n");
	//for(int i=L;i;i--) printf("%d ",r.v[i]);printf("\n");
	//printf("0-expanding ok\n");
	findR();
	//printf("%d\n",R);
	//for(int i=0;i
	//for(int i=1;i
	//printf("---\n");
	init();solve(r,ans1);solve(l,ans2);
	for(int i=0;i<p;i++) printf("%d\n",(ans1[i]-ans2[i]+AwD)%AwD);
}

你可能感兴趣的:(UOJ86 mx的组合数)