对于系数在 $[0, P)$ 范围内的序列 $A(x)$ 和 $B(x)$,一次卷积之后每一项的大小不超过 $nP^2$。找三个 NTT 模数分别做 NTT,再用中国剩余定理(CRT)合并。若不想用大数或者 `__int128` 来合并,可以参考下面的做法。
https://blog.csdn.net/u014609452/article/details/68058602
板子题:P4245 【模板】任意模数NTT
code:
// luogu-judger-enable-o2
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
using namespace std;
using ll = long long;
constexpr int maxn = 2000010;   // 2e6 + 10: capacity of every global array, so it bounds the transform length N
constexpr int g = 3;            // generator passed to qpow() for all three NTT primes
constexpr double eps = 1e-3;    // small positive bias added to qmul()'s floating-point quotient
int mod;                        // output modulus, read in main()
int rev[maxn];                  // bit-reversal permutation for the current transform length
// Returns (a * b) mod c in [0, c) for c > 0, without 128-bit integers.
//
// The quotient floor(a*b/c) is estimated in long double. Its low 64 bits are then
// subtracted with *unsigned* arithmetic. Unsigned wraparound is well defined, whereas
// the signed product a*b would overflow, which is undefined behavior. The estimate
// may be off by one in either direction, so both cases are corrected. This is exact
// when c < 2^62 and the long double estimate is within 1 of the true quotient. That
// holds with the 64-bit-mantissa x87 long double for any a, b < c < 2^62. It also
// holds with a plain 53-bit long double as long as the quotient stays small.
long long qmul(long long a, long long b, long long c){
    a %= c; if (a < 0) a += c;
    b %= c; if (b < 0) b += c;
    const unsigned long long q =
        static_cast<unsigned long long>(static_cast<long double>(a) * b / c);
    long long ret = static_cast<long long>(
        static_cast<unsigned long long>(a) * static_cast<unsigned long long>(b)
        - q * static_cast<unsigned long long>(c));
    if (ret < 0) ret += c;          // quotient was overestimated by one
    else if (ret >= c) ret -= c;    // quotient was underestimated by one
    return ret;
}
// Binary exponentiation: returns a^b mod P for b >= 0.
// Intermediate products base*base and result*base must fit in ll, i.e. P should be below ~3e9.
// Note: for b == 0 the result is 1 (not 1 % P).
inline ll qpow(ll a,ll b,ll P){
    ll base = a % P;
    ll result = 1;
    while (b != 0) {
        if (b & 1) {
            result = result * base % P;
        }
        base = base * base % P;
        b >>= 1;
    }
    return result;
}
// The three NTT primes; each has a large power of two dividing p - 1.
// The transform length N must divide p - 1 for every prime, which caps N at 2^21 (m2).
constexpr int m1 = 998244353;    // 119 * 2^23 + 1
constexpr int m2 = 1004535809;   // 479 * 2^21 + 1
constexpr int m3 = 469762049;    //   7 * 2^26 + 1
constexpr ll _M = (ll)m1 * m2;   // m1 * m2 (about 1.0e18), still fits in a signed 64-bit integer

// Modular inverses consumed by CRT().
const int inv1 = qpow(m1 % m2, m2 - 2, m2);    // m1^{-1} mod m2
const int inv2 = qpow(m2 % m1, m1 - 2, m1);    // m2^{-1} mod m1
const int inv12 = qpow(_M % m3, m3 - 2, m3);   // (m1*m2)^{-1} mod m3
// Given residues a1 (mod m1), a2 (mod m2), a3 (mod m3), reconstructs the unique x in
// [0, m1*m2*m3) and returns x mod `mod`, using only 64-bit arithmetic.
// Assumes each a_k already lies in [0, m_k).
ll CRT(ll a1, ll a2, ll a3){
    // Step 1: merge the first two residues into x12 = x mod (m1*m2).
    const ll part1 = qmul(a1 * m2 % _M, inv2, _M);
    const ll part2 = qmul(a2 * m1 % _M, inv1, _M);
    const ll x12 = (part1 + part2) % _M;

    // Step 2: write x = x12 + k * m1*m2 and pick k in [0, m3) so that x = a3 (mod m3).
    ll diff = (a3 - x12) % m3;
    if (diff < 0) diff += m3;
    const ll k = diff * inv12 % m3;

    // Step 3: evaluate x mod `mod` term by term (k < m3, so no overflow).
    return (k % mod * (_M % mod) % mod + x12 % mod) % mod;
}
// Number-theoretic transform over Z_P for one prime P, using generator g.
// Each instance stores 2*maxn ints of twiddle factors, so the three instances in
// ntt[3] hold 6*maxn ints in total.
struct NTT{
    int P;              // prime modulus of this instance
    int num;            // transform length the twiddle tables were built for
    int w[2][maxn];     // w[1][i] = root^i, w[0][i] = root^{-i}, with root = g^((P-1)/num)

    // Builds the twiddle tables for length m modulo _P.
    // m must divide _P - 1 for g^((_P-1)/m) to be an m-th root of unity.
    void Pre(int _P,int m){
        P = _P;
        num = m;
        const int root = qpow(g, (P - 1) / num, P);
        const int rootInv = qpow(root, P - 2, P);
        w[0][0] = 1;
        w[1][0] = 1;
        for (int i = 1; i < num; ++i) {
            w[1][i] = (ll)w[1][i - 1] * root % P;
            w[0][i] = (ll)w[0][i - 1] * rootInv % P;
        }
    }

    // In-place iterative transform of a[0..N-1] using the global rev[] permutation.
    // r == 1: forward transform; r == 0: inverse transform (including the 1/N scaling).
    // N must equal (or divide) num so the strided twiddle lookups stay in the table.
    void DFT(int* a,int N,int r){
        for (int i = 1; i < N; ++i) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);
        }
        for (int half = 1; half < N; half <<= 1) {
            const int len = half << 1;
            const int stride = num / len;   // step through the length-num twiddle table
            for (int base = 0; base < N; base += len) {
                for (int k = 0; k < half; ++k) {
                    int lo = a[base + k];
                    int hi = 1LL * a[base + half + k] * w[r][stride * k] % P;
                    a[base + k] = (lo + hi) % P;
                    a[base + half + k] = (lo + P - hi) % P;
                }
            }
        }
        if (r == 0) {
            const int nInv = qpow(N, P - 2, P);
            for (int i = 0; i < N; ++i) a[i] = 1LL * a[i] * nInv % P;
        }
    }
}ntt[3];
int A[maxn],B[maxn],C[maxn],D[maxn],tmp[3][maxn];
// Reads n, m, mod and the coefficients A[0..n], B[0..m]. Prints the n+m+1
// coefficients of A*B modulo `mod`, computed with one NTT per prime
// (m1, m2, m3) and merged with CRT().
int main(){
    int n,m;
    scanf("%d%d%d",&n,&m,&mod);
    // Only the answer mod `mod` is needed, so reduce the inputs into [0, mod) up front.
    // This bounds each exact product coefficient by (n+1)*mod^2, far below m1*m2*m3.
    for(int i = 0;i<=n;i++){ scanf("%d",&A[i]); A[i] %= mod; if(A[i] < 0) A[i] += mod; }
    for(int i = 0;i<=m;i++){ scanf("%d",&B[i]); B[i] %= mod; if(B[i] < 0) B[i] += mod; }

    // The product has n+m+1 coefficients, so the smallest power of two N >= n+m+1
    // gives an exact cyclic convolution. L = log2(N).
    int N = 1, L = 0;
    while(N < n + m + 1){ N <<= 1; L++; }

    ntt[0].Pre(m1,N);
    ntt[1].Pre(m2,N);
    ntt[2].Pre(m3,N);
    for(int i = 1;i<N;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(L-1));

    for(int i = 0;i<3;i++){
        const int P = ntt[i].P;
        // Copy exactly N coefficients, reduced modulo this prime, so every butterfly
        // operand in DFT is already below P (and x + P - y cannot overflow int).
        for(int j = 0;j<N;j++){ C[j] = A[j] % P; D[j] = B[j] % P; }
        ntt[i].DFT(C,N,1);
        ntt[i].DFT(D,N,1);
        for(int j = 0;j<N;j++) tmp[i][j] = (ll)C[j] * D[j] % P;
        ntt[i].DFT(tmp[i],N,0);
    }
    for(int i = 0;i< n + m + 1;i++) {
        printf("%lld ",CRT(tmp[0][i],tmp[1][i],tmp[2][i]));
    }
    return 0;
}
拆系数 FFT,又称 MTT。
大概就是把系数拆成 $f(x) = \sqrt{P}\,k(x) + r(x)$ 的形式(其中 $k(x)$、$r(x)$ 的系数都约为 $\sqrt{P}$ 量级),分别卷积后再还原回去。朴素的版本一共需要 7 次 DFT。卷积之后数据在 $nP$ 级别,为避免精度误差,需要用到 `long double`,并预处理单位根。
板子题:P4245 【模板】任意模数NTT
code:不知道为什么跑得比三模数NTT慢啊
// luogu-judger-enable-o2
#include
#include
#include
#include
#include
#include
using namespace std;
constexpr int maxn = 600010;               // 6e5 + 10: capacity of every global array
const long double pi = std::acos(-1.0L);   // pi in long double precision for the twiddle factors
constexpr double eps = 1e-3;               // small margin added before taking sqrt(mod) in main()
using ll = long long;
int mod, M;                                // output modulus, and the split base floor(sqrt(mod))
// Minimal complex number over long double (the extra precision keeps the
// split-coefficient products exact after rounding).
struct cp{
    long double r, i;   // real part, imaginary part

    cp(long double re = 0, long double im = 0) : r(re), i(im) {}

    cp operator + (cp o) const { return cp(r + o.r, i + o.i); }
    cp operator - (cp o) const { return cp(r - o.r, i - o.i); }

    // (r + i*j)(o.r + o.i*j) = (r*o.r - i*o.i) + (r*o.i + i*o.r)*j
    cp operator * (cp o) const {
        const long double re = r * o.r - i * o.i;
        const long double im = r * o.i + i * o.r;
        return cp(re, im);
    }

    // Complex conjugate.
    cp conj() const { return cp(r, -i); }
};
// Input coefficients and the bit-reversal permutation.
int A[maxn], B[maxn];
int rev[maxn];
// a and b are only passed to MTT(), which does not read them.
cp a[maxn], b[maxn];
// High (k) and low (r) parts of the split coefficients of A (…1) and B (…2).
cp k1[maxn], k2[maxn];
cp r1[maxn], r2[maxn];
// Partial products k1*k2, k1*r2 + k2*r1, r1*r2; then the twiddle tables w[1] (forward) / w[0] (inverse).
cp s1[maxn], s2[maxn], s3[maxn];
cp w[2][maxn];
// Final coefficients of the product modulo `mod`.
int ans[maxn];
void get_wn(int N){
for(int i = 0;i<N;i++){
w[1][i] = cp(cos(2*pi/N * i),sin(2*pi/N*i));
w[0][i] = w[1][i].conj();
}
}
// In-place iterative radix-2 FFT of a[0..N-1], using the global rev[] permutation
// and the twiddle table w[r] built by get_wn(N) (r == 1 forward, r == 0 inverse).
// For the inverse transform only the real parts are divided by N; the caller in
// this file (MTT) reads only .r afterwards.
void DFT(cp* a,int N,int r){
    for(int i = 0;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int l = 2; l<=N;l<<=1){
        int m = l >> 1;
        // Blocks start at 0, l, 2l, ... and must lie entirely inside [0, N).
        // (The former `j <= N` ran one extra block on a[N..N+l-1].)
        for(int j = 0; j < N; j += l)
            for(int k = 0;k<m;k++){
                cp u = w[r][N/l*k] * a[j+k+m];
                a[j+k+m] = a[j+k] - u;
                a[j+k] = a[j + k] + u;
            }
    }
    if(r == 0) for(int i = 0;i<N;i++) a[i].r /= N;
}
// Multiplies the global polynomials A and B modulo `mod` with the split-coefficient FFT
// and writes the first N coefficients of the product into ans[].
// Each coefficient x, after reduction into [0, mod), is split as x = k*M + r with
// M = floor(sqrt(mod)) (set in main), so k and r are both O(sqrt(mod)) and
//   A*B = M^2*(k1*k2) + M*(k1*r2 + k2*r1) + r1*r2.
// The three partial convolutions stay small enough to be recovered by rounding the
// long double results (for the input sizes this program targets).
// NOTE: the parameters a and b are unused; the function reads the globals A and B.
void MTT(cp* a,cp* b,int N){
    (void)a; (void)b;
    for(int i = 0;i<N;i++){
        // Reduce first. If A[i] >= mod (e.g. a large coefficient with a small mod),
        // A[i]/M would be far above sqrt(mod). The products would then exceed what
        // long double represents exactly, and the rounded result would be wrong.
        int x = A[i] % mod; if(x < 0) x += mod;
        int y = B[i] % mod; if(y < 0) y += mod;
        k1[i] = cp(x/M,0); r1[i] = cp(x%M,0);
        k2[i] = cp(y/M,0); r2[i] = cp(y%M,0);
    }
    DFT(k1,N,1); DFT(r1,N,1);
    DFT(k2,N,1); DFT(r2,N,1);
    for(int i = 0;i<N;i++) {
        s1[i] = k1[i] * k2[i];
        s2[i] = k1[i] * r2[i] + k2[i] * r1[i];
        s3[i] = r1[i] * r2[i];
    }
    DFT(s1,N,0); DFT(s2,N,0); DFT(s3,N,0);
    for(int i = 0;i<N;i++){
        // +0.5 then truncate: round the (non-negative) real part to the nearest integer.
        int x1 = (ll)(s1[i].r + 0.5) % mod * M * M % mod;
        int x2 = (ll)(s2[i].r + 0.5) % mod * M % mod;
        int x3 = (ll)(s3[i].r + 0.5) % mod;
        // x1 + x2 < 2*mod must fit in int (true for mod <= 1e9+9).
        ans[i] = (x1 + x2);
        if(ans[i]>=mod) ans[i]-=mod;
        ans[i] += x3;
        if(ans[i]>=mod) ans[i]-=mod;
    }
}
// Reads n, m, mod and the coefficients A[0..n], B[0..m]; prints the n+m+1
// coefficients of A*B modulo `mod`, computed by MTT().
int main(){
    int n,m;
    scanf("%d%d%d",&n,&m,&mod);
    // Split base floor(sqrt(mod)); eps is presumably a safety margin against sqrt rounding.
    M = sqrt(mod + eps);
    for(int i = 0;i<= n;i++) scanf("%d",&A[i]);
    for(int i = 0;i<= m;i++) scanf("%d",&B[i]);
    // The product has n+m+1 coefficients: the smallest power of two N >= n+m+1
    // suffices for an exact cyclic convolution. L = log2(N).
    int L = 0, N = 1;
    while(N < n + m + 1){ N <<= 1; L++; }
    for(int i = 1;i<N;i++) rev[i] = (rev[i>>1] >> 1) | ((i&1) << (L-1));
    get_wn(N);
    MTT(a,b,N);
    for(int i = 0;i<n+m+1;i++){
        printf("%d ",ans[i]);
    }
    return 0;
}