计算汉明权重的SWAR(SIMD within a Register)算法

前言

在很久之前,笔者在《布隆过滤器(Bloom Filter)原理及Guava中的具体实现》这篇文章的最后,说到JDK中提供了计算整形数二进制表示中1的数量[即所谓汉明权重(Hamming weight)]的方法,并且说它是Java语言中最强的骚操作之一。本文来简单探究一下骚操作背后的思路。

朴素的SWAR

Integer.bitCount()方法的源码中有一句注释。

// HD, Figure 5-2

说明该方法的原理可以在《Hacker's Delight》这本书的第5章找到。很显然,计算二进制串汉明权重的问题可以转化为计算所有1之和的问题,以32位整数为例,分治的步骤如下。

计算汉明权重的SWAR(SIMD within a Register)算法_第1张图片
  1. 将每2个比特视为一组,一共16组,计算每组中有多少个1。
    利用0x5555555501010101010101010101010101010101)作为掩码。将原数i与该掩码做逻辑与运算,可以取得每组中低位的比特;将原数i右移一位再与该掩码做逻辑与运算,可以取得每组中高位的比特。将两者相加,即可得到每两个比特中的1数量。
i = (i & 0x55555555) + ((i >> 1) & 0x55555555)
  1. 我们已经得到了16组范围在[0, 2]之间的结果,接下来在此基础上将每4个比特(即上一步的每两组)视为一组,一共8组,计算每组中有多少个1。
    利用0x3333333300110011001100110011001100110011)作为掩码。将原数i与该掩码做逻辑与运算,可以取得每组中低2位的比特;将原数i右移2位再与该掩码做逻辑与运算,可以取得每组中高2位的比特。将两者相加,即可得到每4个比特中的1数量。
i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
  1. 我们已经得到了8组范围在[0, 4]之间的结果,接下来在此基础上将每8个比特(即上一步的每两组)视为一组,一共4组,计算每组中有多少个1。
    如法炮制,利用0x0F0F0F0F00001111000011110000111100001111)作为掩码,继续做逻辑与、右移和相加操作即可。
i = (i & 0x0F0F0F0F) + ((i >> 4) & 0x0F0F0F0F)
  1. 继续按照分治思想,两两合并,得出最终的结果。写出完整的方法如下。
public static int bitCount(int i) {
    i = (i & 0x55555555) + ((i >> 1) & 0x55555555);
    i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
    i = (i & 0x0F0F0F0F) + ((i >> 4) & 0x0F0F0F0F);
    i = (i & 0x00FF00FF) + ((i >> 8) & 0x00FF00FF);
    i = (i & 0x0000FFFF) + ((i >> 16) & 0x0000FFFF);
    return i;
}

上述思路一般被称为SWAR(SIMD within a Register),顾名思义,就是在一个寄存器上进行单指令多数据流(single instruction, multiple data)操作,在很多地方都有应用。可见,数值i确实可以只用单个寄存器来存储,不需要额外的存储空间。并且上述方法执行的都是位运算和加法操作,现代处理器对此都有特殊的优化,效率非常高,并且还消灭了相对比较耗时的分支和跳转操作。

这样一看,我们刷题时容易想到的O(n)时间复杂度的解法(即while (i > 0) { i = i & (i - 1); bitCount++; })简直弱爆了。

优化的SWAR

说完了SWAR,下面来看看真正的Integer.bitCount()实现。

public static int bitCount(int i) {
    // HD, Figure 5-2
    i = i - ((i >>> 1) & 0x55555555);
    i = (i & 0x33333333) + ((i >>> 2) & 0x33333333);
    i = (i + (i >>> 4)) & 0x0f0f0f0f;
    i = i + (i >>> 8);
    i = i + (i >>> 16);
    return i & 0x3f;
}

似乎与上面讲的思路有很大出入?实际上仍然是由朴素的SWAR算法优化而来的。

  1. 第一步将每两个比特视为一组时,容易得知有对应关系:
00 = 00 - 00(0个1)
01 = 01 - 00 = 10 - 01(1个1)
10 = 11 - 01(2个1)

亦即对于两个比特组成的数i,汉明权重就是i - (i >>> 1)(其中>>>表示无符号右移)。那么扩展考虑,如果i不止两个比特,按照朴素SWAR算法,在无符号右移之后再与掩码0x55555555做逻辑与,就可以消除掉右移对2比特组的高位的影响(因为i >>> 1的高位不可能为1),i - (i >>> 1)的关系仍然成立,即:

i = i - ((i >>> 1) & 0x55555555)

这样就减少了一次位运算。

  1. 第二步将每4个比特视为一组,由于i + (i >>> 2)可能会产生进位,似乎没什么优化思路,于是维持原状。

  2. 第三步实际上就是计算每个字节的汉明权重,由于每个字节最多只有8个1(计数值为1000),所以先做加法i + (i >>> 4)再做按位与,不会受到进位的影响,能够保证结果的正确性,即:

i = (i + (i >>> 4)) & 0x0f0f0f0f

这样又减少了一次位运算。

  1. 后面的步骤可以视为计算每2个字节、每4个字节的汉明权重,所以思路与第三步相同,可以先做加法再做按位与,即:
i = (i + (i >>> 8)) & 0x00ff00ff
i = (i + (i >>> 16)) & 0x0000ffff

但是,我们已经知道Integer的汉明权重肯定不会超过32,即100000,所以实际上只需要在最后与0x3F(即111111)做按位与,就能得到最终的结果了,又减少了一次位运算。

i = i + (i >>> 8)
i = i + (i >>> 16)
i = i & 0x3f

根据Hacker's Delight书中的说法,优化的SWAR汉明权重算法只需要21条指令就可以执行完毕,确实非常精妙了。

The End

Happy Saturday night.

民那晚安晚安。

你可能感兴趣的:(算法/数据结构)