多项式模板

应要求放一份模板

#include<bits/stdc++.h>
#define pb push_back
#define cs const
#define poly vector<int>
using namespace std;
cs int Mod = 998244353;
int add(int a, int b){ return a + b >= Mod ? a + b - Mod : a + b; }
int dec(int a, int b){ return a - b < 0 ? a - b + Mod : a - b; }
int mul(int a, int b){ return 1ll * a * b % Mod; }
int ksm(int a, int b){ int as=1; for(;b;b>>=1,a=mul(a,a)) if(b&1) as=mul(as,a); return as; }
void Add(int &a, int b){ a = add(a, b); }
void Dec(int &a, int b){ a = dec(a, b); }
void Mul(int &a, int b){ a = mul(a, b); }
cs int K = 18, N = 1 << K | 5;
int n, fc[N], ifc[N], iv[N];
int *w[K+1], rv[N], bit, up;
void output(poly a){ for(int i=0; i<a.size(); i++) cout<<a[i]<<" "; puts(""); }
void fc_init(int n){
	fc[0]=fc[1]=ifc[0]=ifc[1]=1;
	for(int i=2; i<=n; i++) fc[i]=mul(fc[i-1],i);
	for(int i=2; i<=n; i++) ifc[i]=mul(ifc[i-1],iv[i]);
}
void NTT_init(){
	for(int i=1; i<=K; i++) w[i]=new int[1<<(i-1)];
	int wn=ksm(3,(Mod-1)/(1<<K)); w[K][0]=1;
	for(int i=1; i<(1<<(K-1)); i++) w[K][i]=mul(w[K][i-1],wn);
	for(int i=K-1;i;i--) for(int j=0;j<(1<<(i-1));j++) w[i][j]=w[i+1][j<<1]; 
	iv[0]=iv[1]=1; for(int i=2; i<=(1<<K); i++) iv[i]=mul(Mod-Mod/i,iv[Mod%i]);
} 
void init(int deg){
	up=1; bit=0; while(up<deg) up<<=1,++bit; 
	for(int i=0; i<up; i++) rv[i]=(rv[i>>1]>>1)|((i&1)<<(bit-1));
}
poly operator - (poly a, poly b){
	if(a.size()<b.size()) a.resize(b.size());
	for(int i=0; i<(int)b.size(); i++) Dec(a[i],b[i]); return a;
}
poly divx(poly a){
	for(int i=0; i<(int)a.size(); i++) a[i]=a[i+1];
	return a.pop_back(), a;
}
void NTT(poly &a, int typ=1){
	for(int i=0; i<up; i++) if(i<rv[i]) swap(a[i],a[rv[i]]);
	for(int i=1,l=1;i<up;i<<=1,++l)
	for(int j=0; j<up; j+=(i<<1))
	for(int k=0; k<i; k++){
		int x=a[k+j], y=mul(w[l][k],a[k+j+i]);
		a[k+j]=add(x,y); a[k+j+i]=dec(x,y);
	} if(typ==-1){
		reverse(a.begin()+1,a.end());
		for(int i=0,iv=ksm(up,Mod-2); i<up; i++) Mul(a[i],iv);
	}
}
poly operator * (poly a, poly b){
	int dg=a.size()+b.size()-1; 
	if(a.size()<=32||b.size()<=32){
		poly c(dg,0);
		for(int i=0; i<(int)a.size(); i++)
		for(int j=0; j<(int)b.size(); j++)
		Add(c[i+j],mul(a[i],b[j])); return c;
	} init(dg); a.resize(up); b.resize(up); NTT(a); NTT(b);
	for(int i=0; i<up; i++) Mul(a[i],b[i]); 
	NTT(a,-1); a.resize(dg); return a;
}
poly inv(poly a, int deg){
	poly b(1,ksm(a[0],Mod-2)),c;
	for(int lim=2; (lim>>1)<deg; lim<<=1){
		c.resize(lim); init(lim<<1);
		for(int i=0; i<lim; i++) c[i]=i<(int)a.size()?a[i]:0;
		c.resize(up); b.resize(up); NTT(c); NTT(b);
		for(int i=0; i<up; i++) Mul(b[i],dec(2,mul(b[i],c[i])));
		NTT(b,-1); b.resize(lim); 
	} b.resize(deg); return b;
}
poly deriv(poly a){
	for(int i=0; i+1<(int)a.size(); i++) a[i]=mul(i+1,a[i+1]);
	a.pop_back(); return a;
}
poly integ(poly a){
	a.pb(0);
	for(int i=a.size()-1;i;i--) a[i]=mul(iv[i],a[i-1]);
	a[0]=0; return a;
}
poly ln(poly a, int dg){
	if(dg==-1) dg=a.size();
	a=integ(deriv(a)*inv(a,dg)); a.resize(dg); return a;
}
poly Exp(poly a, int dg){
	if(dg==-1) dg = a.size(); poly b(1,1), c;
	for(int lim=2; lim<(dg<<1); lim<<=1){
		c=ln(b,lim); Dec(c[0],1);
		for(int i=0; i<lim; i++) c[i]=dec(i<(int)a.size()?a[i]:0,c[i]);
		b=b*c; b.resize(lim);
	} b.resize(dg); return b;
}
poly polypw(poly a, int k, int dg){ 
	a = ln(a,dg); 
	for(int i=0; i<(int)a.size(); i++) Mul(a[i],k);
	return Exp(a,dg);
}
void operator -= (poly &a, cs poly &b){
	if(a.size() < b.size()) a.resize(b.size());
	for(int i=0; i<(int)b.size(); i++) Dec(a[i],b[i]); 
}
poly operator / (poly a, poly b){
	int deg = a.size()-b.size()+1;
	reverse(a.begin(),a.end()), a.resize(deg);
	reverse(b.begin(),b.end()), a = a * inv(b,deg);
	a.resize(deg); reverse(a.begin(),a.end()); return a;	
}
poly operator % (poly a, poly b){
	if(a.size()<b.size()) return a;
	a -= b * (a / b); a.resize(b.size()-1); return a;
}
poly f[N];
void work(int x, int l, int r, int *S){
	if(l==r){ f[x].pb(dec(0,S[l])); f[x].pb(1); return; }
	int mid=(l+r)>>1; 
	work(x<<1,l,mid,S); work(x<<1|1,mid+1,r,S);
	f[x] = f[x<<1] * f[x<<1|1];
}
void solve(int x, int l, int r, poly F){
	if(l==r) return coe[l]=F[0],void();
	int mid=(l+r)>>1;
	solve(x<<1,l,mid,F%f[x<<1]);
	solve(x<<1|1,mid+1,r,F%f[x<<1|1]);
} 

你可能感兴趣的:(tmp)