参考:扩展大步小步法解决离散对数问题
离散对数主要是求解这样一类问题:
a x ≡ b ( m o d m ) a^x \equiv b \pmod m ax≡b(modm) 大概就是 ( m o d m ) \pmod m (modm)意义下的对数
这里只考虑 ( m , a ) (m,a) (m,a)为1的情况。一般来说,给出的m是一个质数。
设 x = A ⌈ p ⌉ + B x = A \lceil \sqrt{p} \rceil + B x=A⌈p⌉+B,其中 0 ≤ B < ⌈ p ⌉ 0 \leq B < \lceil \sqrt{p} \rceil 0≤B<⌈p⌉, 0 ≤ A < ⌈ p ⌉ 0 \leq A < \lceil \sqrt{p} \rceil 0≤A<⌈p⌉。则方程变成
a A ⌈ p ⌉ + B ≡ b ( m o d p ) a^{A\lceil \sqrt{p} \rceil + B} \equiv b \pmod p aA⌈p⌉+B≡b(modp)
两边同时乘以 A ⌈ p ⌉ A \lceil \sqrt{p} \rceil A⌈p⌉的逆元,则方程变为
a B ≡ b ⋅ a − A ⌈ p ⌉ ( m o d p ) a^{B} \equiv b\cdot a^{-A\lceil \sqrt{p} \rceil} \pmod p aB≡b⋅a−A⌈p⌉(modp)
由于 A A A, B B B都是 O ( p ) O(\sqrt{p}) O(p)级别的数,可以预处理出左边的所有值,然后按从小到大的顺序枚举 A A A,查找左边是否有值对应。采用手写 h a s h hash hash或者C++11的 u n o r d e r e d unordered unordered_ m a p map map就可以 O ( 1 ) O(1) O(1)时间查询。这样的复杂度是 O ( p + p ) O(\sqrt{p} + \sqrt{p}) O(p+p)的。
如果需要对 n n n个 b b b求解这个模方程,考虑将 x x x 表示为 A ∗ K + B A*K + B A∗K+B,则 A ≤ ⌊ p k ⌋ A \leq \lfloor \frac{p}{k}\rfloor A≤⌊kp⌋,预处理左边之后,只需要在右边重复查找 n n n次即可。时间复杂度为 O ( K + ⌊ p k ⌋ ∗ n ) O(K + \lfloor \frac{p}{k}\rfloor * n) O(K+⌊kp⌋∗n)。
对于 x y = n m o d    m x^y = n \mod m xy=nmodm 求解 y y y。
LL discrete_log(int x,int _n,int m){
unordered_map<LL,int>rec;
int s=(int)(sqrt((double)m));
for(; (LL)s*s<=m;)++s;
LL cur=1;
for(int i=0;i<s;++i){
rec[cur]=i;
// cur=cur*x%m;
MUL(cur,x,m);
}
LL mul=cur;
cur=1;
for(int i=0;i<s;++i){
LL more=(LL)_n;
MUL(more,qpow(cur,m-2,m),m);
if(rec.count(more)){
return i*s +rec[more];
}
// cur= cur*mul%m;
MUL(cur,mul,m);
}
return -1;
}
zoj 18.1月赛 E
#include
using namespace std;
#define lson rt<<1
#define rson rt<<1|1
typedef long long ll;
const int mod = 1e9+7;
const int _mod = mod - 1;
const int _P = 5;
const int maxn = 1e5+10;
ll a[maxn],Log[maxn];
struct node{
int l,r;
ll sum,add,mul;
}tr[maxn<<2];
void Add(ll& x,ll y,int P){
x+=y;
if(x>=P) x -= P;
}
void Mul(ll& x,ll y,int P){
x *= y;
if(x>=P) x %= P;
}
ll qpow(ll a,ll b,ll P){
ll ret = 1;
while(b){
if(b&1) Mul(ret,a,P);
Mul(a,a,P);
b>>=1;
}
return ret;
}
void up(int rt){
tr[rt].sum = tr[lson].sum + tr[rson].sum;
if(tr[rt].sum >= _mod) tr[rt].sum -= _mod;
}
void down(int rt){
int m = (tr[rt].l + tr[rt].r)>>1;
// cout<<"down "<