牛客国庆Day5B(蒙哥马利算法)

题目链接:https://ac.nowcoder.com/acm/contest/205/B?&headNav=www

原来窝被这算法虐过么= =!

这个算法其实是用来解决RSA中计算公钥的问题的,是个底层常数优化算法。。

主要解决两个问题:

  • 蒙哥马利约减,即 t p − 1   ( m o d   m ) tp^{-1}\ (mod\ m) tp1 (mod m)

  • 蒙哥马利乘模,即 x y   ( m o d   m ) x y\ (mod\ m) xy (mod m)

这个算法的目的就是避免取模和除法来降低常数
算法过程,直接看杜教板子比看网上的博客来得简单。。

首先预处理几个参数:

b为数制,这里所有的数都以b进制的蒙哥马利表示法表示
p = b k p=b^k p=bk p ≥ m p\ge m pm
i n v = m − 1   m o d   p inv=m^{-1}\ mod\ p inv=m1 mod p
r 2 = p 2   m o d   m r2=p^2\ mod\ m r2=p2 mod m

取b进制是为了方便进行位运算,所以b一般取2,然后p取 2 64 2^{64} 264的话取模的时候自然溢出就行了,比较方便
然后关于蒙哥马利表示法,就是x的蒙哥马利表示为 x p   ( m o d   m ) xp\ (mod\ m) xp (mod m)
所有数的运算都在此表示法下进行

蒙哥马利约减:

t ≤ m 2 t\le m^2 tm2,可以用 ( t − ( t ∗ i n v   m o d   p ) ∗ m ) / p (t-(t*inv\ mod\ p)*m )/p (t(tinv mod p)m)/p 代替 t p − 1   m o d   m tp^{-1}\ mod \ m tp1 mod m
证明:
因为 t − t ∗ i n v ∗ m ≡ t + t ( − m − 1 ∗ m ) ≡ 0   ( m o d   p ) t-t*inv*m\equiv t+t(-m^{-1}*m)\equiv0\ (mod\ p) ttinvmt+t(m1m)0 (mod p)
所以 p ∣ ( t + t ∗ i n v ∗ m ) p|(t+t*inv*m) p(t+tinvm),即 p ∣ ( t + ( t ∗ i n v   m o d   p ) ∗ m ) p|(t+(t*inv\ mod\ p)*m) p(t+(tinv mod p)m)
故有 ( t + ( t ∗ i n v   m o d   p ) ∗ m ) / p ≡ t p − 1 + ( t ∗ i n v   m o d   p ) p − 1 m ≡ t p − 1   ( m o d   m ) (t+(t*inv\ mod\ p)*m )/p\equiv tp^{-1}+(t*inv\ mod \ p)p^{-1}m\equiv tp^{-1}\ (mod \ m) (t+(tinv mod p)m)/ptp1+(tinv mod p)p1mtp1 (mod m)
在数值上,由于
( t − ( t ∗ i n v   m o d   p ∗ m ) / p ) ≤ m 2 / p ≤ m (t-(t*inv\ mod \ p*m)/p)\le m^2/p\le m (t(tinv mod pm)/p)m2/pm
( t − ( t ∗ i n v   m o d   p ∗ m ) / p ) ≥ − m (t-(t*inv\ mod \ p*m)/p)\ge -m (t(tinv mod pm)/p)m
所以只需判断正负即可

然后就可以用蒙哥马利约减来求出一个数的蒙哥马利表示了,直接对 x ∗ r 2 x*r2 xr2进行约减即可
从蒙哥马利表示法中还原出原数也只需做一个约减就可以了

蒙哥马利乘模:


x ^ = x p   ( m o d   m ) \hat{x}=xp\ (mod\ m) x^=xp (mod m)
y ^ = y p   ( m o d   m ) \hat{y}=yp\ (mod\ m) y^=yp (mod m)
那么在蒙哥马利表示法中, x y ^ = x y p \hat{xy}=xyp xy^=xyp
x y ^   ≡   x ^ y ^ / p   ( m o d   m ) \hat{xy} \ \equiv \ \hat{x}\hat{y}/p\ (mod \ m) xy^  x^y^/p (mod m)
直接对 x ∗ y x*y xy进行一次约减就可以了

基本就这么多。。说那么多其实只要有板子就可以。。
这里就贴杜教的板子,窝的板子实在是太丑了。。。

 
 
 
 
 
 
 
 

#include 
using namespace std;

typedef long long ll;
typedef unsigned long long u64;
typedef __int128_t i128;
typedef __uint128_t u128;

struct Mod64 {
    Mod64() :n_(0) {}
    Mod64(u64 n) :n_(init(n)) {}
    static u64 init(u64 w) { return reduce(u128(w) * r2); }
    static void set_mod(u64 m) {
        mod = m; assert(mod & 1);
        inv = m; for (int i = 0; i < 5; ++i) inv *= 2 - inv * m;
        r2 = -u128(m) % m;
    }
    static u64 reduce(u128 x) {
        u64 y = u64(x >> 64) - u64((u128(u64(x)*inv)*mod) >> 64);
        return ll(y)<0 ? y + mod : y;
    }
    Mod64& operator += (Mod64 rhs) { n_ += rhs.n_ - mod; if (ll(n_)<0) n_ += mod; return *this; }
    Mod64 operator + (Mod64 rhs) const { return Mod64(*this) += rhs; }
    Mod64& operator -= (Mod64 rhs) { n_ -= rhs.n_; if (ll(n_)<0) n_ += mod; return *this; }
    Mod64 operator - (Mod64 rhs) const { return Mod64(*this) -= rhs; }
    Mod64& operator *= (Mod64 rhs) { n_ = reduce(u128(n_)*rhs.n_); return *this; }
    Mod64 operator * (Mod64 rhs) const { return Mod64(*this) *= rhs; }
    u64 get() const { return reduce(n_); }
    static u64 mod, inv, r2;
    u64 n_;
};

u64 Mod64::mod, Mod64::inv, Mod64::r2;


int t, k;
u64 A0, A1, M0, M1, C, M;

void Run()
{
    scanf("%d", &t);
    while (t--)
    {
        scanf("%llu%llu%llu%llu%llu%llu%d", &A0, &A1, &M0, &M1, &C, &M, &k);
        Mod64::set_mod(M);
        Mod64 a0(A0), a1(A1), m0(M0), m1(M1), c(C), ans(1), a2(0);
        ans *= a0; ans *= a1;
        for (int i = 2; i <= k; ++i)
        {
            a2 = a1;
            a1 = m0 * a1 + m1 * a0 + c;
            a0 = a2;
            ans *= a1;
        }
        printf("%llu\n", ans.get());
    }
}

int main()
{
    Run();
    return 0;
}

你可能感兴趣的:(其他算法)