Modular exponentiation

1. 递归实现:

/*
 * Compute (base ^ exponent) mod modulus using recursive square-and-multiply.
 *
 * base     - value to raise; it is reduced modulo `modulus` before any
 *            multiplication so intermediate products stay below 2^62.
 * exponent - power to raise to; values <= 0 are treated as 0 (result 1).
 * modulus  - divisor; must be positive to produce a meaningful residue.
 *
 * Returns the residue in [0, modulus). Returns 0 when modulus <= 1
 * (modulus == 1 always yields 0; non-positive moduli are rejected instead
 * of dividing by zero or being converted to a huge unsigned value).
 */
unsigned long long modular_pow_recursive(unsigned long long base, int exponent, int modulus)
{
    if (modulus <= 1)
    {
        return 0;
    }
    if (exponent <= 0)
    {
        /* Also stops the recursion for negative exponents. */
        return 1;
    }

    unsigned long long m = (unsigned long long)modulus;

    /* modulus fits in an int, so after this both factors are < 2^31. */
    base %= m;

    /* Parenthesised on purpose: `==` binds tighter than `&`. */
    if ((exponent & 1) == 1)
    {
        return (base * modular_pow_recursive(base, exponent - 1, modulus)) % m;
    }

    unsigned long long half = modular_pow_recursive(base, exponent >> 1, modulus);
    return (half * half) % m;
}

刚开始的时候,最后两行写成了如下的形式:

    int ret =  modular_pow_recursive(base, exponent >> 1,  modulus) %modulus;
    return (ret * ret) % modulus

测试的时候发现结果不对,查了半天,才发现是溢出(overflow)了:ret 是 int,ret * ret 按 int 计算,结果超出了 int 的表示范围。

改成前面代码里的形式,或者:

    int ret =  modular_pow_recursive(base, exponent >> 1,  modulus) %modulus;
    return ((unsigned long long )ret * ret) % modulus
就可以了。

2. 迭代实现:

/*
 * Compute (base ^ exponent) mod modulus using iterative right-to-left
 * binary exponentiation.
 *
 * base     - value to raise; it is reduced modulo `modulus` up front so that
 *            neither `result * base` nor `base * base` can exceed 64 bits.
 * exponent - power to raise to; values <= 0 yield 1.
 * modulus  - divisor; must be positive to produce a meaningful residue.
 *
 * Returns the residue in [0, modulus). Returns 0 when modulus <= 1
 * (modulus == 1 always yields 0; non-positive moduli are rejected instead
 * of dividing by zero or being converted to a huge unsigned value).
 */
unsigned long long modular_pow(unsigned long long base, int exponent, int modulus)
{
    if (modulus <= 1)
    {
        return 0;
    }

    unsigned long long m = (unsigned long long)modulus;
    unsigned long long result = 1;

    /* modulus fits in an int, so every factor below stays < 2^31. */
    base %= m;

    while (exponent > 0)
    {
        /* Multiply in the current power of base when this bit is set. */
        if (exponent & 1)
        {
            result = (result * base) % m;
        }
        base = (base * base) % m;
        exponent >>= 1;
    }

    return result;
}

3. 一个比较容易想到的实现:

/*
 * Compute (base ^ exponent) mod modulus by repeated multiplication
 * (one multiply-and-reduce step per unit of exponent).
 *
 * base     - value to raise; it is reduced modulo `modulus` up front so that
 *            `result * base` cannot exceed 64 bits.
 * exponent - power to raise to; values <= 0 yield 1.
 * modulus  - divisor; must be positive to produce a meaningful residue.
 *
 * Returns the residue in [0, modulus). Returns 0 when modulus <= 1
 * (anything modulo 1 is 0, including the exponent == 0 case; non-positive
 * moduli are rejected instead of dividing by zero).
 */
unsigned long long modular_pow_rude(unsigned long long base, int exponent, int modulus)
{
    if (modulus <= 1)
    {
        return 0;
    }

    unsigned long long m = (unsigned long long)modulus;
    unsigned long long result = 1;

    /* modulus fits in an int, so both factors below stay < 2^31. */
    base %= m;

    for (int i = 0; i < exponent; i++)
    {
        result = (result * base) % m;
    }

    return result;
}


测试代码:

/*
 * Self-check: evaluates 19^78 mod 199879 with the recursive, iterative and
 * repeated-multiplication implementations and compares the results.
 *
 * Returns 0 when all three agree, 1 otherwise. The residue itself is not
 * used as the exit status because it would not fit in one.
 */
int main(int argc, char **argv)
{
    (void)argc;
    (void)argv;

    const unsigned long long base = 19;
    const int exponent = 78;
    const int modulus = 199879;

    unsigned long long recursive = modular_pow_recursive(base, exponent, modulus);
    unsigned long long iterative = modular_pow(base, exponent, modulus);
    unsigned long long rude = modular_pow_rude(base, exponent, modulus);

    if (recursive != iterative || iterative != rude)
    {
        return 1;
    }
    return 0;
}

References:
https://en.wikipedia.org/wiki/Modular_exponentiation

你可能感兴趣的:(Modular exponentiation)