【洛谷P4245】任意模数NTT

任意模数NTT

首先我们取三个模数,使得它们的乘积大于 nP2 n P 2 7226+1 7 ∗ 2 2 6 + 1 998244353 998244353 479221+1 479 ∗ 2 21 + 1 这三个数就挺合适的,它们互质且原根都是3。
然后对于结果的每一位,我们就得到了中国剩余定理形式的式子:

ansa1(modm1) a n s ≡ a 1 ( mod m 1 )

ansa2(modm2) a n s ≡ a 2 ( mod m 2 )

ansa3(modm3) a n s ≡ a 3 ( mod m 3 )

当然, m1m2m3 m 1 m 2 m 3 敲大的,你肯定不能暴力合并,所以可以先把前两项合并了,得到 ansA(modM) a n s ≡ A ( mod M ) ,然后就有 Mx+Aa3(modm3) M x + A ≡ a 3 ( mod m 3 ) ,然后求出 x=M1(a3A) x = M − 1 ( a 3 − A ) 之后,得到 ans=Mx+A a n s = M x + A ,此时在计算的时候直接模题目给定的模数。

代码

#include
using namespace std;
#define RI register int
int read() {
    int q=0;char ch=' ';
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
    return q;
}
typedef long long LL;
const int N=262150,mm[3]={7*(1<<26)+1,998244353,479*(1<<21)+1},G=3;
int n,m,kn=1,len,mod;
int a[N],b[N],k1[N],k2[N],ans[3][N],rev[N];
int ksm(int x,int y,int p) {
    int re=1;
    for(;y;y>>=1,x=1LL*x*x%p) if(y&1) re=1LL*re*x%p;
    return re;
}
LL ksc(LL x,LL y,LL p) {return (x*y-(LL)((long double)x/p*y+1e-8)*p+p)%p;}
void NTT(int *a,int n,int p,int x) {
    for(RI i=0;iif(rev[i]>i) swap(a[i],a[rev[i]]);
    for(RI i=1;i1) {
        int gn=ksm(G,(p-1)/(i<<1),p);
        for(RI j=0;j1) {
            int g=1,t1,t2;
            for(RI k=0;k1LL*g*gn%p) {
                t1=a[j+k]%p,t2=1LL*g*a[j+i+k]%p;
                a[j+k]=(t1+t2)%p,a[j+i+k]=(t1-t2+p)%p;
            }
        }
    }
    if(x==1) return;
    int inv=ksm(n,p-2,p);reverse(a+1,a+n);
    for(RI i=0;i1LL*a[i]*inv%p;
}
void work(int o) {
    for(RI i=0;i1),NTT(k2,kn,mm[o],1);
    for(RI i=0;i1LL*k1[i]*k2[i]%mm[o];
    NTT(ans[o],kn,mm[o],-1);
}
int main()
{
    n=read(),m=read(),mod=read();
    for(RI i=0;i<=n;++i) a[i]=read();
    for(RI i=0;i<=m;++i) b[i]=read();
    while(kn<=n+m) kn<<=1,++len;
    for(RI i=1;i>1]>>1)|((i&1)<<(len-1));
    for(RI i=0;i<3;++i) work(i);
    LL M=1LL*mm[0]*mm[1];
    LL kl1=ksc(mm[1],ksm(mm[1]%mm[0],mm[0]-2,mm[0]),M);
    LL kl2=ksc(mm[0],ksm(mm[0]%mm[1],mm[1]-2,mm[1]),M);
    for(RI i=0;i<=n+m;++i) {
        int t0=ksm(ans[0][i],mm[1]-2,mm[1]),t1=ksm(ans[1][i],mm[0]-2,mm[0]);
        LL A=(ksc(kl1,ans[0][i],M)+ksc(kl2,ans[1][i],M))%M;
        LL k=((ans[2][i]-A)%mm[2]+mm[2])%mm[2]*ksm(M%mm[2],mm[2]-2,mm[2])%mm[2];
        printf("%lld ",((M%mod)*(k%mod)%mod+A%mod)%mod);
    }
    return 0;
}

你可能感兴趣的:(数学)