【学习笔记】[AGC063D] Many CRT

有点难。

首先判掉 gcd ⁡ ( c , d ) > 1 \gcd(c,d)>1 gcd(c,d)>1的情况。记 M = lcm 1 ≤ k ≤ n ( c + d k ) M=\text{lcm}_{1\le k\le n}(c+dk) M=lcm1kn(c+dk)

我们将同余式变形为: d x ≡ d a + k d b ( m o d c + k d ) dx\equiv da+kdb\pmod{c+kd} dxda+kdb(modc+kd)

d x ≡ d a − b c ( m o d c + k d ) dx\equiv da-bc\pmod{c+kd} dxdabc(modc+kd)

将模数合并,即 d x ≡ d a − b c ( m o d M ) dx\equiv da-bc\pmod{M} dxdabc(modM)

这等价于求 方程 d x − M y = d a − b c dx-My=da-bc dxMy=dabc x x x最小非负整数解,可以直接变形为 x = d a − b c + M y d x=\frac{da-bc+My}{d} x=ddabc+My。因此只要求出 M   m o d   d M\bmod d Mmodd即可。

考虑 lcm \text{lcm} lcm算法的本质,发现对于 > max ⁡ ( n , d ) >\max(n,d) >max(n,d)的质因数可以直接乘起来,对于 ≤ max ⁡ ( n , d ) \le \max(n,d) max(n,d)的质因数可以暴力跳,这样就做完了。

同余这方面还是比较神奇的,可以看看这道题 [WC2021] 斐波那契 。

大佬的题解写的太完美了我就咕掉了。

remark \text{remark} remark 在细节处理上出了问题。。。看来还是太菜了。

#include
#define fi first
#define se second
#define ll long long
#define pb push_back
#define inf 0x3f3f3f3f
using namespace std;
const int mod=998244353;
ll n,a,b,c,d;
ll gcd(ll x,ll y){
    return y==0?x:gcd(y,x%y);
}
ll lcm(ll x,ll y){
    return x/gcd(x,y)*y;
}
const int N=1e6+5;
int prime[N],cnt;
bool vs[N];
void init(int n){
    for(int i=2;i<=n;i++){
        if(vs[i]==0)prime[++cnt]=i;
        for(int j=1;j<=cnt&&prime[j]<=n/i;j++){
            vs[i*prime[j]]=1;
            if(i%prime[j]==0)break;
        }
    }
}
ll nums[N];
void exgcd(ll a,ll b,ll &x,ll &y,ll r){
    if(b==0){
        x=1,y=0,r=1;
        return;
    }exgcd(b,a%b,y,x,r),y-=x*(a/b);
}
ll fpow(ll x,ll y=mod-2){
    ll z(1);
    for(;y;y>>=1){
        if(y&1)z=z*x%mod;
        x=x*x%mod;
    }return z;
}
signed main(){
	ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>a>>b>>c>>d;
    ll g=gcd(c,d);
    if(b%g){
        cout<<-1;
        return 0;
    }
    ll ok=0,res=0,a2=a;
    if(g>1){
        ok=1,b/=g,c/=g,d/=g,a2=a%g,a/=g;
    }
    init(n);
    for(int i=0;i<n;i++)nums[i]=c+i*d;
    ll tM=1;
    for(int i=0;i<n;i++){
        if((ll)(1e12)/tM<nums[i]){
            tM=1e12;
        }
        else{
            tM*=nums[i];
        }
    }
    if(tM!=1e12){
        ll M=1;
        for(int i=0;i<n;i++)M=lcm(M,nums[i]);
        ll A=d,B=M,x,y,r;
        exgcd(A,B,x,y,r),x=(x%M+M)%M;
        res=((__int128)(d*a-b*c)%M*x%M+M)%M;
        if(ok){
            cout<<(res*g+a2)%mod;
        }
        else{
            cout<<res%mod;
        }
        return 0;
    }
    ll M=1,M2=1;
    for(int i=1;i<=cnt;i++){
        ll p=prime[i];int cnt=0;
        if(d%p==0)continue;
        ll x,y,A=d,B=p,r;
        exgcd(A,B,x,y,r);
        x=(-x%p*c%p+p)%p;if(x>=n)continue;
        for(int j=x;j<n;j+=p){
            int tmp=0;
            while(nums[j]%p==0)tmp++,nums[j]/=p;
            cnt=max(cnt,tmp);
        }for(int j=1;j<=cnt;j++)M=M*p%d,M2=M2*p%mod;
    }
    for(int i=0;i<n;i++)M=M*(nums[i]%d)%d,M2=M2*(nums[i]%mod)%mod;
    for(ll i=0;i<=d;i++){
        if((a*d-b*c+M*i)%d==0&&(__int128)tM*i+d*a-b*c>=0){
            res=((d*a-b*c)%mod+M2*i%mod)*fpow(d)%mod;
            break;
        }
    }res=(res+mod)%mod;
    if(ok){
        cout<<(res*g+a2)%mod;
    }else{
        cout<<res;
    }
}

你可能感兴趣的:(算法)