扩展Crt

好像去年这个时候我就已经看过一遍了。。但是noi的时候一点印象没有就GG了。。补知识点的时候发现自己还是不会,就稍微学了一下。。。

扩展crt就是求满足
在这里插入图片描述
的一组x(模数不要求互质)

做法就是假设你搞出了前k组的一个最小正整数解x,想推出前k+1的解
M = ∏ i = 1 k m i M=\prod_{i=1}^k m_i M=i=1kmi,然后所有 x + t ∗ M x+t*M x+tM都是满足的
那么对于新填的一组方程 x ≡ a k + 1 ( m o d   m k + 1 ) x\equiv a_{k+1}(mod\ m_{k+1}) xak+1(mod mk+1)我们要找的是令 x + t ∗ M ≡ a k + 1 ( m o d   m k + 1 ) x+t*M\equiv a_{k+1}(mod\ m_{k+1}) x+tMak+1(mod mk+1)最小的 t t t,exgcd求出最小解即可。

啊,数学真是太奇妙了…

UPD:贴上noid2t1代码(快乐考试打的splay…)

#include
#include
#include
#include//long long!!!!!!!!!!!!!!!!!!
#include//多组数据!!!!!!!!!!!
#include
#define INF 214748364700000000ll
#define N 100020
using namespace std;
typedef long long LL;
inline LL read(){
	LL x=0,f=1;char c;
	do c=getchar(),f=c=='-'?-1:f; while(!isdigit(c));
	do x=(x<<3)+(x<<1)+c-'0',c=getchar(); while(isdigit(c));
	return x*f;
}
int T,n,m,cnt;
LL a[N],h[N],p[N],g[N];
bool t1;
void exgcd(LL a,LL b,LL &x,LL &y){
	if(!b){
		y=0,x=1;
		return;
	}
	exgcd(b,a%b,y,x);
	y=y-a/b*x;
}
LL gcd(LL a,LL b){
    return b?gcd(b,a%b):a;
}
inline LL GetNi(LL k,LL p){
    LL x,y;
    exgcd(k,p,x,y);
    return (x+p)%p;
}
struct Data{
    LL a,m;
    Data(){}
    Data(LL a,LL m):a(a),m(m){}
}da[1000020];
struct Node{
    Node *fa,*ch[2];
    int siz,cnt;
    LL x;
    Node(LL);
    inline void maintain(){
        siz=ch[0]->siz+ch[1]->siz+cnt;
    }
    inline int dir(){
        if(fa->ch[0]==this) return 0;
        if(fa->ch[1]==this) return 1;
        return -1;
    }
    inline int cmp(LL k){
        if(x==k) return -1;
        return k<x?0:1;
    }
}*root,*null,*tmp;
queue<Node*>q;
Node::Node(LL _):x(_){
    siz=cnt=1;
    fa=ch[0]=ch[1]=null;
}
inline void Del(Node *&x){
    if(x!=NULL && x!=null)
        q.push(x);
    x=null;
}
inline Node* New(LL x){
    if(q.empty()) return new Node(x);
    static Node *k;
    k=q.front();q.pop();
    Del(k->ch[0]);Del(k->ch[1]);
    k->cnt=k->siz=1;k->x=x;
    return k;
}
inline void initn(){
    Del(root);
    null=new Node(0);
    null->siz=null->cnt=0;
    null->fa=null->ch[0]=null->ch[1]=null;
    root=null;
}
inline void Rotate(Node *x,int d){
    Node *k=x->ch[d^1];
    x->ch[d^1]=k->ch[d];
    if(x->ch[d^1]!=null) x->ch[d^1]->fa=x;
    k->ch[d]=x;
    if(x->fa!=null) x->fa->ch[x->dir()]=k;
    k->fa=x->fa;x->fa=k;
    x->maintain();k->maintain();
}
inline void Splay(Node *x,Node *y){
    while(x->fa!=y){
        if(x->fa->fa!=y && x->dir()==x->fa->dir())
            Rotate(x->fa->fa,x->dir()^1);
        Rotate(x->fa,x->dir()^1);
    }
    if(y==null) root=x;
}
void Insert(LL x,Node *&k,Node *fa){
    if(k==null){
        k=new Node(x);
        k->fa=fa;
        tmp=k;
        return;
    }
    int d=k->cmp(x);
    if(!~d){
        k->cnt++;k->siz++;
        tmp=k;
    }
    else Insert(x,k->ch[d],k);
    k->maintain();
}
LL Lower(LL k,Node *x){
    if(x==null) return -INF;
    if(k<x->x) return Lower(k,x->ch[0]);
    else return max(x->x,Lower(k,x->ch[1]));
}
LL Upp(LL k,Node *x){
    if(x==null) return INF;
    if(k>=x->x) return Upp(k,x->ch[1]);
    else return min(x->x,Upp(k,x->ch[0]));
}
inline void AddNew(LL x){
    Insert(x,root,null);
    Splay(tmp,null);
}
void print(Node *x){
    if(x==null) return;
    print(x->ch[0]);
    printf("(%lld %d)",x->x,x->cnt);
    print(x->ch[1]);
}
Node* LowerP(LL k,Node *x){
    if(x==null) return null;
    if(k<=x->x) return LowerP(k,x->ch[0]);
    Node *t=LowerP(k,x->ch[1]);
    return t==null?x:t;
}
Node* UpperP(LL k,Node *x){
    if(x==null) return null;
    if(k>=x->x) return UpperP(k,x->ch[1]);
    Node *t=UpperP(k,x->ch[0]);
    return t==null?x:t;
}
inline void Delete(LL k){
    static Node *a,*b;
    a=LowerP(k,root);b=UpperP(k,root);
    Splay(a,null);Splay(b,a);
    if(root->ch[1]->ch[0]->cnt>1) root->ch[1]->ch[0]->cnt--,root->ch[1]->ch[0]->siz--;
    else Del(root->ch[1]->ch[0]),root->ch[1]->ch[0]=null;
    root->ch[1]->maintain();root->ch[0]->maintain();
}
LL ans;
inline LL mul(LL a,LL b,LL p){
    LL sum=0;
    a=a%p;b=b%p;
    while(b){
        if(b&1) sum=(sum+a)%p;
        a=(a+a)%p;
        b>>=1;
    }
    return sum%p;
}
inline LL calc(){
    LL M=da[1].m,x=da[1].a;
    for(int i=2;i<=cnt;i++){
        LL g=gcd(M,da[i].m),t=((da[i].a%da[i].m-x%da[i].m)%da[i].m+da[i].m)%da[i].m;
        if(t%g!=0) return -1;
        LL xx,y;
        exgcd(M,da[i].m,xx,y);
        xx=mul(xx,t/g,da[i].m/g);x=xx*M+x;
        assert(M>=0);
        M=M*(da[i].m/g);
        x=(x%M+M)%M;
    }
    return x%M;
}
inline void solve(){
    initn();
    ans=cnt=0;
    AddNew(-INF);AddNew(INF);
    for(int i=1;i<=m;i++)
        AddNew(a[i]);
    for(int i=1;i<=n;i++){
        LL t=Lower(h[i],root);
        if(t!=-INF) Delete(t);
        else{
            t=Upp(-INF,root);
            Delete(t);
        }
        ///t*x=h[i](mod p[i])
        LL pp=gcd(t,p[i]);
        if(h[i]%pp!=0) return void(printf("-1\n"));
        da[++cnt]=Data(mul(h[i]/pp,GetNi(t,p[i]),p[i]),p[i]/pp);
        AddNew(g[i]);
    }
    printf("%lld\n",calc());
}
inline void solve1(){
    initn();
    ans=0;
    AddNew(-INF);AddNew(INF);
    for(int i=1;i<=m;i++)
        AddNew(a[i]);
    for(int i=1;i<=n;i++){
        LL t=Lower(h[i],root);
        if(t!=-INF) Delete(t);
        else{
            t=Upp(-INF,root);
            Delete(t);
        }
        if(h[i]%t==0) ans=max(ans,h[i]/t);
        else ans=max(ans,h[i]/t+1);
        AddNew(g[i]);
    }
    printf("%lld\n",ans);
}
int main(){
	T=read();
	while(T--){
		t1=true;
		n=read();m=read();
		for(int i=1;i<=n;i++) h[i]=read();
		for(int i=1;i<=n;i++) p[i]=read(),t1=t1&(p[i]==1);
		for(int i=1;i<=n;i++) g[i]=read();
		for(int i=1;i<=m;i++) a[i]=read();
		if(t1) solve1();
        else solve();
	}
	return 0;
}

你可能感兴趣的:(算法讲解)