快速指数运算方法

问题描述

    我们在一些常用的运算里免不了要计算某些指数函数。比如说给定两个正整数a, b,要求a**b,即a的b次方。这个问题看起来很简单,最直接的办法就是我连续乘以a,b次,得到的就是这个结果。这种方法的时间复杂度也比较低,相当于O(N)。实际上,我们还有更好的办法,使得它的时间复杂度达到O(logN)。

 

分析

    实际上这个问题本身并不是很复杂,从一开始看的时候似乎也能找到一点类似的苗头。比如说我们要求2 ** 7。我们知道,它可以拆分成2**4 * 2 ** 3。而对于2的4次方可以拆分成2**2 再做一次平方。对于一个理想的数字,比如说这个指数本身就是2的多少次方的,像4, 我们只要求出2**2,然后再在它的基础上求它自己的平方,这不就得到2**4了吗?而对于前面2的7次方这个数字,我们将它拆成了2**4乘以2**3。既然2的4次方可以这样来求,对于这些奇数来说呢?实际上我们可以进一步拆分的。2**3可以拆成2**2再乘以2。这样,它们都可以被拆分成一系列的数字给拼起来。

    这个过程似乎带来了一点思路。我们针对前面这个问题再深入的看一下,假设我们要求的这个数字的指数是7, 对它的拆分是将它拆成4和3。因为4是最接近它的2的指数,它可以通过两次求平方运算得到。而3呢,则需要再进一步按照这种方式来拆分。相当于找最接近它的那个2的指数。那么怎么找这个最接近它本身的2的指数呢?我们看7它的二进制表示形式:

快速指数运算方法_第1张图片

    这是一个用4个位来表示的二进制数字形式。其实,它最接近的那个2的指数不就是为1的最高位么?那么我们最终要凑这个数字无非就是将各个为1的位对应的数字加起来。既然我们的目的是将这些数字加起来,那么完全就没有必要考虑从左往右的拆分了。我们完全可以这样来求,看一个数字的某个位是否为1,比如我们从最低位开始。如果是1, 我们就加上一个2**0,也就是1, 而再往后一位对应的第1位的位置, 如果这一个位置也是1的话,则加一个2**1的值。后面依次类推,对应的分别是2**2, 2**3等等。我们这里相当于得出来一个根据二进制数字来凑这个整数的过程。类似的伪代码如下:

 

while(b != 0) {  // 如果这个要凑的数字不为0
    r = b % 2;    // 取出最低位
    b /= 2;
    k = 0;
    if(r == 1)
        sum += (2 ** k)  // 如果当前的位为1, 则表示当前位有数字,加上对应的2的k次方。
    k++;
}

    这里的过程是根据2进制数来求对应整数的。而我们这里具体的问题是要求对应的指数。这里也比较简单,我们对应的更加高的一个位相当于原来的数字对它本身做了个平方。而如果对应的这个位为1, 我们就将它乘以原来的一个数字,相当于前面的相加变相乘。所以套用这个框架,我们就可以很容易得到如下的代码:

 

public static long calculate(int a, int b) {
        if(a <= 0 || b <= 0)
            throw new IllegalArgumentException();
        int r;
        long x = a;
        long y = 1;
        while(b != 0) {
            r = b % 2;
            b = b / 2;
            if(r == 1)
                y *= x;
            x = x * x;
        }
        return y;
    }
     这里结合了两个部分,首先要将这个数字拆分成对应的2进制的形式。不过不需要完整的拆开,每次用b % 2得到的就是当前最低位的值。而将b / 2则相当于b的二进制数字向右移动一位,将第二位的数字变成最低位。这样每移一次我们将对应的底数求平方,就对应到这个位的值。

    从时间复杂度的角度来看,因为我们在循环里每次对这个指数向右移动1位。这个数字对应的是logN个二进制位。所以我们总共遍历的次数为logN。也就是时间复杂度为O(logN)。

 

总结

    求指数函数的快速方法无非是用到了几个典型的数学特性。我们知道,对一个数的指数求平方的话,相当于对指数乘以2。所以问题的根源就归结到我们要怎么样来凑出这个给定的指数值。而如果我们对于数字的2进制表示形式比较清楚的话,会发现它们可以归结到一个将数字转化成二进制的样式的问题。

 

参考材料

Mathematics for computer science

你可能感兴趣的:(algorithms,java)