Baby step Giant step算法

题意: 求满足a^x=b(mod n)的最小的整数x。

分析: 很多地方写到n是素数的时候可以用 Baby step,Giant step, 其实研究过 Baby step,Giant step算法以后,你会发现  它能解决    “ n与a互质”的情况,而并不是单纯的n是素数的情况。如果a与n不是互质的,那么我们需要处理一下原方程,让a与n互质,然后再用 Baby step,Giant step解出x即可 

Baby step,Giant step算法思想: 对于a与n互质,那么则有a^phi(n)=1(mod n),   对于n是素数phi(n) == n-1, 否则phi(n) < n-1, 所以x的取值只要在0----n-2之中取就可以了。

当n很小时,可以直接枚举,但当n很大时,肯定会超时,Baby step,Giant step就是用了一种O(sqrt(n)*log(n))的方法枚举了所有的0-----n-2。令m = sqrt(n);

我们可以预处理出a^0,a^1,.........a^m,都放入哈希表中, 然后  (a^m)^i+v(哈希表里的其中一个值)就一定是解,每次枚举i(0-----m-1),计算出v,判断v是否出现在哈希表中,如果有就是解。  对于m为什么取sqrt(n)是为了复杂度的平衡,这一点是跟分块算法很相似的。

对于a与n不互质的情况分析: 令 t = gcd(a,n),那么a与n都约去t,当然b也要约去t(不能约去就无解),约去一个t以后方程就变为   aa*a^(x-1) = bb(mod nn), (其中  aa = a/t    bb = b/t    nn = n/t) , 这里nn还可能与a不互质,那么我们一直拿出一个新的a对(a, bb, nn)约去t,直到a与nnn....(nnn...表示约去若干次t以后的n)互质。以下用(用三个字母表示约去若干次后,如bbb) 则 结果为aa^ c*a^(x-c) = bbb(mod nnn),      我们让等式左右分别乘以aa^c关于nnn的逆元       变为 a^(x-c) = w     (mod  nnn) ,    w = bbb *(aa^c)^(-1) 。   a^x = w  (mod n)可以用 bbb *(aa^c)^(-1)Baby step,Giant step直接求出,如果有解那把未知数+c。

具体看代码中的cal函数。

注意: 在以上过程中x有可能<c,所以我们必须每约去一个t就要特判一下当前情况aa 与 bb就说明当前c是解。

上面是完全复制别人的博客,接下来的代码是我自己结合理解写的,感觉写得不错

typedef long long ll;
#define maxn 100000+5
struct poi
{
	int id;
	ll val;
	friend bool operator < (const poi& A,const poi& B)
	{
		return A.val!=B.val?A.val<B.val:A.id<B.id;
	}
}dat[maxn];

ll exgcd(ll a,ll b,ll& x,ll& y)
{
	if(b==0)
	{
		x=1;
		y=0;
		return a;
	}
	ll t=exgcd(b,a%b,y,x);
	y-=a/b*x;
	return t;
}

ll inv(ll a,ll n)
{
	ll x,y;
	ll d=exgcd(a,n,x,y);
	if(d!=1)return -1;
	return (x%n+n)%n;
}

int bise(int l,int r,ll k)
{
	while(l<=r)
	{
		int m=(l+r)>>1;
		if(dat[m].val==k)return dat[m].id;
		else if(dat[m].val<k)l=m+1;
		else r=m-1;
	}
	return -1;
}

ll qlow(ll a,ll n,ll m)
{
	ll ans=1;
	while(n)
	{
		if(n&1)ans=ans*a%m;
		a=a*a%m;
		n>>=1;
	}
	return ans;
}

int BSGS(ll k,ll n,ll p)
{
	ll t;
	for(int i=0,t=1;i<100;i++,t=t*k%p)
		if(t==n)
			return i;
	
	ll v=1;
	int d=0;
	while((t=__gcd(k,p))!=1)
	{
		if(n%t)return -1;
		d++;
		p/=t;
		n/=t;
		v=k/t*v%p;
	}

	ll m=(ll)ceil(sqrt(p));
	for(int i=0,t=1;i<m;i++,t=t*k%p)
	{
		dat[i].id=i;
		dat[i].val=t;
	}
	sort(dat,dat+m);
	int cnt=1;
	for(int i=1;i<m;i++)
		if(dat[i].val!=dat[cnt-1].val)
			dat[cnt++]=dat[i];

	ll km=inv(k,p);
	km=qlow(km,m,p);
	n=n*inv(v,p)%p;
	for(int i=0;i<m;i++,n=n*km%p)
	{
		int pos=bise(0,cnt-1,n);
		if(pos>=0)return i*m+pos+d;
	}
	return -1;
}


你可能感兴趣的:(数论,BSGS)