首先介绍一下二进制流。
假如有下4个数值
[ 47, 19, 38, 53 ]
首先每个数字对应的二进制分别如下
十进制数值 | 二进制数值 |
---|---|
47 | 0x0010 1111 |
19 | 0x0001 0011 |
38 | 0x0010 0110 |
53 | 0x0011 0101 |
我们需要将这些数字保存到一个二进制文件中。 注:这里不考虑BigEndian还是LittleEndian。
那么这个文件中,按照输入的顺序,即 [ 47, 19, 38, 53 ]
来说,文件应该是这样的,
(53) 0x0011 0101 (38) 0x0010 0110 (19) 0x0001 0011 (47) 0x0010 1111
一共4个byte。
如果你发现,其实每个数值的前2位都是0,如果去掉这两位,剩下有效的数据是24bit,就是3个byte。
这样就剩下了一个byte。
BitPacker是用于做这种节省工作的,它将数值按指定的bit位数存入一个二进制流中。
class BitPacker:
"""Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
Note that for some bandwidth (1.5, 3), the codebook representation
will not cover an integer number of bytes.
Args:
bits (int): number of bits per value that will be pushed.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: tp.IO[bytes]):
self._current_value = 0
self._current_bits = 0
self.bits = bits
self.fo = fo
def push(self, value: int):
"""Push a new value to the stream. This will immediately
write as many uint8 as possible to the underlying file-object."""
self._current_value += (value << self._current_bits)
self._current_bits += self.bits
while self._current_bits >= 8:
lower_8bits = self._current_value & 0xff
self._current_bits -= 8
self._current_value >>= 8
self.fo.write(bytes([lower_8bits]))
def flush(self):
"""Flushes the remaining partial uint8, call this at the end
of the stream to encode."""
if self._current_bits:
self.fo.write(bytes([self._current_value]))
self._current_value = 0
self._current_bits = 0
self.fo.flush()
以及对应的解包类
class BitUnpacker:
"""BitUnpacker does the opposite of `BitPacker`.
Args:
bits (int): number of bits of the values to decode.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: tp.IO[bytes]):
self.bits = bits
self.fo = fo
self._mask = (1 << bits) - 1
self._current_value = 0
self._current_bits = 0
def pull(self) -> tp.Optional[int]:
"""
Pull a single value from the stream, potentially reading some
extra bytes from the underlying file-object.
Returns `None` when reaching the end of the stream.
"""
while self._current_bits < self.bits:
buf = self.fo.read(1)
if not buf:
return None
character = buf[0]
self._current_value += character << self._current_bits
self._current_bits += 8
out = self._current_value & self._mask
self._current_value >>= self.bits
self._current_bits -= self.bits
return out
下面是测试样例,
if __name__ == '__main__':
length: int = 4
bits: int = 6
tokens: tp.List[int] = [ 47, 19, 38, 53 ]
rebuilt: tp.List[int] = []
buf = io.BytesIO()
packer = BitPacker(bits, buf)
for token in tokens:
packer.push(token)
packer.flush()
buf.seek(0)
unpacker = BitUnpacker(bits, buf)
while True:
value = unpacker.pull()
if value is None:
break
rebuilt.append(value)
assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
# The flushing mechanism might lead to "ghost" values at the end of the stream.
assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), len(tokens), bits)
for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
assert a == b, (idx, a, b)
下图画了两个byte的bit位置,当第一个数 (47) 0x0010 1111push进来,因为不足一个byte长度,所以不做特殊处理。
输入第二个字符0x0001 0011,直观上需要把它放到上一个数字的左边;
那么实际上就是将它向左移动6个bit, 即 0x0001 0011 <<6, 如下图
此时长度已经大于一个byte,所以可以将低位byte打包起来,即 0x1110 1111,剩下数据如图
此时长度大于一个byte,所以可以将低位byte打包起来,即 0x0110 0100,剩下输入如图
![在这里插入图片描述](https://img-blog.csdnimg.cn/6aace43f6d4b49c6a3771beef6670f8c.png
同样,按照上面的逻辑,打包最低位byte, 0x11010110
压入的数值一定要大于或者等于BitPacker中指定的bit长度,否则将会被数值截断。