BSGS算法 扩展BSGS算法 模板

BSGS算法 可以快速的求出,满足a^x ≡ b(mod p)的最小的非负整数x (p必须为质数,即 gcd(a,p) == 1)

我们先将x拆分成i*m-j的形式(其中m为sqrt(p)向上取整的值,则原式化为  a^{i*m-j} ≡ b (mod p)  )   显然 0<=j

移向后得  a^{i*m} ≡ b*a^{j} (mod p)

我们从  0-m  枚举  j,并将  b*a^{j}  的所有值存入map中

接着在从  1-m  枚举  i,算出所有的  a^{i*m}

如果一个  i  对应的  a^{i*m}  的值已经在 map 中,则表明  i*m-j  为一个解,输出此时的解即可

因为  j<=m,所以求出的解随 i 的增大而减小,所以最先求出的 i 所对的解,即为所求的解

 

根据 poj2417所写   http://poj.org/problem?id=2417 

//#include
#include
#include
#include
#include
#define ll long long
using namespace std;
ll BSGS(ll y,ll z,ll p){
	mapmp;
    mp.clear();
    ll m=ceil(sqrt(p));
    ll tmp=1;
    for(int i=1;i<=m;i++){
      tmp=tmp*y%p;//上一个板子这里用了快速幂,会重复计算很多,但作为一个相信玄学的人,我还是把两个都打上来了  PS:应该上一个板子能解决的这个都能解决
      mp[tmp*z%p]=i; //这里会覆盖掉之前的解,即保存的是最大的i。
    }
    ll res=1,ans;
    for(int i=1;i<=m;i++){  
		res=res*tmp%p;
		if(mp[res]){
           ans=m*i-mp[res]; //因为mp[res]保存的是最大的,所以ans是最小的解。
           return ans;
       }
    }
    return -1;
}
int main(){
    ll t,z,y,p; //注意要开long long 类型,否则会溢出 
    while(~scanf("%lld%lld%lld",&p,&y,&z)){
        ;
        ll ans=BSGS(y,z,p);// y^x = z (mod p)
        if(ans==-1){  //方程无解 
            printf("no solution\n");
        }else{
            printf("%lld\n",ans);
        }
    }
    return 0;
}

另外,使用map可能会造成时间超限,所以有时需要自己手写hash表,但这会用掉大量内存(依然是 poj2417)

//#include
#include
#include
#include
#include
#include
#include
#define ll long long
using namespace std;
const int MOD=1e7;
namespace Ha { //哈希表
    int tot, h[MOD+5], ne[MOD+5],p[MOD+5];
    ll ha[MOD+5];
    void insert(ll x, ll num) { //插入操作
        ll t = num%MOD;
        p[++tot] = x, ha[tot] = num, ne[tot] = h[t], h[t] = tot;
    }
    ll query(ll tar) { //查询操作
        for(int i = h[tar%MOD]; i != -1; i = ne[i])
            if(ha[i] == tar) return p[i];
        return -1;
    }
}
using namespace Ha;
ll BSGS(ll a, ll b, ll p) {
    a %= p, b %= p;
    if(a == 0 && b != 0) return -1; //a%p==0时显然无解
    if(a == 0 && b == 0) return 1;
    if(b == 1) return 0;
    ll m = ceil(sqrt((double)p)), q = 1, x = 1;
    memset(h, -1, sizeof h); //记得清空
    for(ll j = 0; j < m; ++j) insert(j, q*b%p), q = q*a%p; //暴力枚举j并存入表中
    for(ll i = 1, j; i <= m; ++i) {
        x = x*q%p, j = query(x); //在表中找
        if(j != -1) return i*m-j; //找到解了,直接返回
    }
    return -1;
}
int main(){
    ll t,z,y,p; //注意要开long long 类型,否则会溢出 
    while(~scanf("%lld%lld%lld",&p,&y,&z)){
        ;
        ll ans=BSGS(y,z,p);// y^x = z (mod p)
        if(ans==-1){  //方程无解 
            printf("no solution\n");
        }else{
            printf("%lld\n",ans);
        }
    }
    return 0;
}

但是上面的哈希表不知道为什么,在https://ac.nowcoder.com/acm/contest/885/C中会段错误,而且这道题分块很有讲究,所以有更好的板子如下:

#define ll long long
const int MOD=1e7;
ll quickpow(ll a,ll b,ll mod){
    ll ans=1;
    while(b){
        if(b&1)
            ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}

ll hs[MOD], head[MOD], Next[MOD], id[MOD], top;
void insert(ll x, ll y){
    ll k = x % MOD;
    hs[top] = x;
    id[top] = y;
    Next[top] = head[k];
    head[k] = top++;
}
ll find(ll x){
    ll k = x % MOD;
    for (ll i = head[k]; i != -1; i = Next[i])
        if (hs[i] == x)
            return id[i];
    return -1;
}
ll up,down,mul;
//unordered_mapM;
void init(){
	memset(head,-1,sizeof(head));
	top=0;
}
void initBSGS(ll y,ll p){
    //M.clear();
    init();
    up=ceil(pow(p,2.0/3));//在一般情况下,大步小步的值均为 sqrt(p),但该题会有q次询问,每次询问都要在哈希表中进行查找,所以我们要让小步的值尽可能的小
    down=ceil(pow(p,1.0/3));
    ll num=1;
    for(int i=0; i<=up; i++){
        if(i==up)
            mul=num;
        insert(num,i);
        //M[num]=i;
        num=num*y%p;
    }
}

ll BSGS(ll y,ll z,ll p){
    ll num=quickpow(z,p-2,p);
    for(int i=1; i<=down+1; i++){
        num=num*mul%p;
        ll f=find(num);
        if(f!=-1) return i*up-f;
        //if(M.count(num)) return i*up-M[num];
            
    }
    return -1;
}

 

扩展BSGS是在模数不为质数时用的 

直接上代码,转自 https://www.cnblogs.com/TheRoadToTheGold/p/8478697.html

#include
#include
#include
#include 
#include 
 
using namespace std;

typedef long long LL;
 
mapmp;

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}
 
int get_gcd(int a,int b) { return !b ? a : get_gcd(b,a%b); }
 
int Pow(int a,int b,int mod)
{
    int res=1;
    for(;b;a=1LL*a*a%mod,b>>=1)
        if(b&1) res=1LL*res*a%mod;
    return res;
}   
 
int ex_BSGS(int A,int B,int C)
{
    if(B==1) return 0;
    int k=0,tmp=1,d;
    while(1)
    {
        d=get_gcd(A,C);
        if(d==1) break;
        if(B%d) return -1;
        B/=d; C/=d;
        tmp=1LL*tmp*(A/d)%C;
        k++;
        if(tmp==B) return k;
    }
    mp.clear();
    int mul=B;
    mp[B]=0;
    int m=ceil(sqrt(1.0*C));
    for(int j=1;j<=m;++j) 
    {
        mul=1LL*mul*A%C;
        mp[mul]=j;
    }
    int am=Pow(A,m,C);
    mul=tmp;
    for(int j=1;j<=m;++j)
    {
        mul=1LL*mul*am%C;
        if(mp.count(mul)) return j*m-mp[mul]+k;
    }
    return -1;
}
 
int main()
{
    int A,C,B;
    int ans;
    while(1)
    {
        read(A); read(B); read(C); 
        if(!A) return 0;
        ans=ex_BSGS(A,B,C);// y^x = z (mod p)
        if(ans==-1) puts("No Solution");
        else cout<

 

你可能感兴趣的:(算法小笔记)