Byte Pair Encoding (BPE): Algorithm and Code Notes

The Byte Pair Encoding (BPE) Algorithm

BPE is a method for building the subword vocabulary used by Transformer models. It roughly consists of the following steps:

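  1. Split the text in the corpus into characters.
  2. Count how often each pair of adjacent symbols (bigram) occurs.
  3. Merge the most frequent pair into a new symbol and add it to the vocabulary.
  4. Repeat steps 2 and 3 until the vocabulary reaches a preset size or no pair is left to merge.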
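The four steps above can be sketched as a toy training loop. This is a minimal illustration, assuming a tiny in-memory corpus of whitespace-separated words and character-level (not byte-level) symbols; it is not the code OpenAI used to learn GPT-2's merges, and the function name `learn_bpe` is hypothetical.

```python
from collections import Counter

def learn_bpe(corpus, num_merges):
    """Toy BPE training (hypothetical helper): returns the learned merges in order."""
    # Step 1: split every word into characters, remembering word frequencies
    word_freqs = Counter(corpus.split())
    words = {w: tuple(w) for w in word_freqs}
    merges = []
    for _ in range(num_merges):
        # Step 2: count adjacent symbol pairs, weighted by word frequency
        pair_counts = Counter()
        for w, symbols in words.items():
            for a, b in zip(symbols, symbols[1:]):
                pair_counts[(a, b)] += word_freqs[w]
        if not pair_counts:
            break  # Step 4: nothing left to merge
        # Step 3: pick the most frequent pair (ties go to the pair counted first)
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # apply the merge to every word, left to right, without overlap
        for w, symbols in words.items():
            out, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            words[w] = tuple(out)
    return merges

print(learn_bpe("low low low lower lowest", num_merges=3))
# [('l', 'o'), ('lo', 'w'), ('low', 'e')]
```

The position of each pair in the returned list plays the same role as a line number in GPT-2's vocab.bpe: earlier merges have higher priority when encoding.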

The notes below use GPT-2's BPE code as the example (the listing is minGPT's lightly modified port of OpenAI's encoder.py).

The full code is shown below:

"""
BPE (Byte Pair Encoding): converts an arbitrary UTF-8 string into a sequence of integer indices for the downstream neural network.

bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into
sequences of integers, where each integer represents small chunks of commonly
occurring characters. This implementation is based on openai's gpt2 encoder.py:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
but was mildly modified because the original implementation is a bit confusing.
I also tried to add as many comments as possible, my own understanding of what's
going on.
"""

import os
import json
import regex as re
import requests

import torch

# -----------------------------------------------------------------------------

def bytes_to_unicode():
    """
    Map every byte (8 bits, i.e. 2**8 = 256 values) to a Unicode character.
    Some bytes correspond to "ugly" characters, e.g. chr(0) is '\x00', so OpenAI remaps them.
    
    Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode
    character that represents it visually. Some bytes have their appearance preserved
    because they don't cause any trouble. These are defined in list bs. For example:
    chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!".
    However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these
    bytes, into new characters in a range where chr() returns a single nice character.
    So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8).
    In particular, the space character is 32, which we can see by ord(' '). Instead,
    this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'.
    So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters
    that "look nice", either in their original form, or a funny shifted character
    like 'Ā', or 'Ġ', etc.
    """
    # the 188 integers that render fine in their original form and need no shifting
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict
    # now get the representations of the other 68 integers that do need shifting
    # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop
    n = 0
    for b in range(2**8):
        if b not in bs:
            # if this byte is "ugly" then map it to the next available "nice" character
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    d = dict(zip(bs, cs))
    return d

def get_pairs(word):
    """
    Get all adjacent character pairs (bigrams) in a word
    
    Return all bigrams as a set of tuples, of consecutive elements in the iterable word.
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

class Encoder:

    def __init__(self, encoder, bpe_merges):
        # byte encoder/decoder
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        # bpe token encoder/decoder
        self.encoder = encoder  # maps a token string to its integer index
        self.decoder = {v:k for k,v in self.encoder.items()}  # maps an integer index back to its token string
        # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # the splitting pattern used for pre-tokenization
        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment
        """
        ok so what is this regex looking for, exactly?
        python re reference: https://docs.python.org/3/library/re.html
        - the vertical bars | is OR, so re.findall will chunkate text as the pieces match, from left to right
        - '\'s' would split up things like Andrej's -> (Andrej, 's)
        - ' ?\p{L}+': optional space followed by 1+ unicode code points in the category "letter"
        - ' ?\p{N}+': optional space followed by 1+ unicode code points in the category "number"
        - ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number
        - '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) UNLESS they are followed by non-whitespace
                       so this will consume whitespace characters in a sequence but exclude the last whitespace in
                       that sequence. that last whitespace has the opportunity to then match the optional ' ?' in
                       earlier patterns.
        - '\s+': 1+ whitespace characters, intended probably to catch a full trailing sequence of whitespaces at end of string
        So TLDR:
        - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens
        - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces
        """
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")  # pre-tokenization: split the string into runs of letters, digits, whitespace and other characters, with special rules for English contractions
        self.cache = {}

    def bpe(self, token):
        """
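        Apply BPE to one pre-tokenized token, splitting it further based on the precomputed bpe_ranks.
        bpe_ranks: the priority of each bigram merge, i.e. the order in which the merges were learned from a large corpus (lower rank = learned earlier)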
        
        this function uses self.bpe_ranks to iteratively merge all the possible bpe tokens
        up the tree. token is a string of one individual 'word' (after regex tokenization)
        and after byte encoding, e.g. 'Ġthere'.
        """
        # token is a string of one individual 'word', after byte encoding, e.g. 'Ġthere'

        # memoization, for efficiency
        if token in self.cache:  # the cache speeds up BPE for repeated tokens
            return self.cache[token]

        word = tuple(token) # individual characters that make up the token, in a tuple
        pairs = get_pairs(word) # get all bigrams

        if not pairs:
            return token

        while True:

            # find the next lowest rank bigram that can be merged
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))  # prefer the pair with the lowest rank, i.e. the merge learned earliest
            if bigram not in self.bpe_ranks:  # none of the remaining pairs is in the merge table
                break # no more bigrams are eligible to be merged
            first, second = bigram

            # we will now replace all occurrences of (first, second) in the list of current
            # words into one merged token first_second, in the output list new_word
            new_word = []
            i = 0
            while i < len(word):  # merge the pair (it may occur more than once)

                # find the next occurrence of first in the sequence of current words
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # first does not occur again
                    new_word.extend(word[i:])
                    break

                # if this occurrence is also followed by second, then merge them into one
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            # all occurrences of (first, second) have been merged to first_second
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)

        # concat all words into a string, and use ' ' as the separator. Note that
        # by now all characters have been byte encoded, guaranteeing that ' ' is
        # not used in the actual data and is a 'special' delimiter character
        word = ' '.join(word)

        # cache the result and return
        self.cache[token] = word
        return word

    def encode(self, text):
        """ 
        Convert a string into a sequence of integer indices
        
        string goes in, list of integers comes out
        """
        bpe_idx = []
        
        # pre-tokenize the input text into string tokens (words, roughly speaking)
        tokens = re.findall(self.pat, text)  # coarse pre-split with the regex
        
        # process each token into BPE integers
        for token in tokens:  # inside each token, BPE repeatedly merges pairs
            # encode the token as a bytes (b'') object
            token_bytes = token.encode('utf-8')
            # translate all bytes to their unicode string representation and flatten
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            # perform all the applicable bpe merges according to self.bpe_ranks
            token_merged = self.bpe(token_translated).split(' ')
            # translate all bpe tokens to integers
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            # extend our running list of all output integers
            bpe_idx.extend(token_ix)
        return bpe_idx

    def encode_and_show_work(self, text):
        """ debugging function, same as encode but returns all intermediate work """
        bpe_idx = []
        parts = []
        tokens = re.findall(self.pat, text)
        for token in tokens:
            token_bytes = token.encode('utf-8')
            token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
            token_merged = self.bpe(token_translated).split(' ')
            token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
            bpe_idx.extend(token_ix)
            parts.append({
                'token': token,
                'token_bytes': token_bytes,
                'token_translated': token_translated,
                'token_merged': token_merged,
                'token_ix': token_ix,
            })
        out = {
            'bpe_idx': bpe_idx, # the actual output sequence
            'tokens': tokens, # result of pre-tokenization
            'parts': parts, # intermediates for each token part
        }
        return out

    def decode(self, bpe_idx):
        """ 
        Convert a sequence of integer indices back into a string
        
        list of integers comes in, string comes out 
        """
        # inverse map the integers to get the tokens
        tokens_merged = [self.decoder[token] for token in bpe_idx]
        # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes
        tokens_flat = ''.join(tokens_merged)
        tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])
        # recover the full utf-8 string
        text = tokens_bytes.decode('utf-8', errors='replace')
        return text

def get_file(local_file, remote_file):
    """ downloads remote_file to local_file if necessary """
    if not os.path.isfile(local_file):
        print(f"downloading {remote_file} to {local_file}")
        response = requests.get(remote_file)
        response.raise_for_status()  # fail loudly instead of caching an error page
        with open(local_file, "wb") as f:
            f.write(response.content)

def get_encoder():
    """
    Initialize from OpenAI's official GPT-2 tokenizer files (cached locally)
    
    Returns an instance of the GPT BPE Encoder/Decoder
    and handles caching of "database" files.
    """
    home_dir = os.path.expanduser('~')
    cache_dir = os.path.join(home_dir, '.cache', 'mingpt')
    os.makedirs(cache_dir, exist_ok=True)

    # load encoder.json that has the raw mappings from token -> bpe index
    encoder_local_file = os.path.join(cache_dir, 'encoder.json')
    encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json'
    get_file(encoder_local_file, encoder_remote_file)
    with open(encoder_local_file, 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token

    # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure
    # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab
    vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
    vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe'
    get_file(vocab_local_file, vocab_remote_file)
    with open(vocab_local_file, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # light postprocessing: strip the version on first line and the last line is a blank
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    assert len(bpe_merges) == 50000 # 50,000 merged tokens

    # construct the Encoder object and return
    enc = Encoder(encoder, bpe_merges)
    return enc

# -----------------------------------------------------------------------------

class BPETokenizer:
    """ PyTorch-aware class that wraps the Encoder above """

    def __init__(self):
        self.encoder = get_encoder()

    def __call__(self, text, return_tensors='pt'):
        # PyTorch only; here because we want to match huggingface/transformers interface
        assert return_tensors == 'pt'
        # single string input for now, in the future potentially a list of strings
        assert isinstance(text, str)
        # encode and create a "batch dimension" of 1
        idx = [self.encoder.encode(text)]
        # wrap into PyTorch tensor
        out = torch.tensor(idx, dtype=torch.long)
        return out

    def decode(self, idx):
        # ensure a simple 1D tensor for now
        assert idx.ndim == 1
        # decode indices to text
        text = self.encoder.decode(idx.tolist())
        return text
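Because get_encoder downloads encoder.json and vocab.bpe, it can be handy to check the byte-level layer offline. The sketch below repeats bytes_to_unicode in compact form (assumed to mirror the function above) and verifies that it is a one-to-one mapping over all 256 byte values, that the space byte maps to 'Ġ', and that arbitrary UTF-8 text survives the round trip.

```python
def bytes_to_unicode():
    # same construction as above: keep 188 "printable" bytes, shift the other 68 to 256 + n
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "héllo world 你好"
# encode: str -> UTF-8 bytes -> one printable character per byte
encoded = "".join(byte_encoder[b] for b in text.encode("utf-8"))
# decode: characters -> bytes -> str
decoded = bytes(byte_decoder[c] for c in encoded).decode("utf-8")

print(byte_encoder[ord(" ")])  # Ġ
print(encoded)
assert decoded == text
```

Every multi-byte character (such as each Chinese character here) becomes several printable symbols, so BPE never sees an unknown character: any input can fall back to the 256 single-byte tokens.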

To understand the whole BPE process, start from the bpe method of the Encoder class. Its code is:

def bpe(self, token):

    # the cache speeds up BPE for repeated tokens
    if token in self.cache:
        return self.cache[token]

    word = tuple(token) # individual characters that make up the token, in a tuple
    pairs = get_pairs(word) # get all bigrams

    if not pairs:
        return token

    while True:

        # find the next lowest rank bigram that can be merged
        bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))  # prefer the pair with the lowest rank, i.e. the merge learned earliest
        if bigram not in self.bpe_ranks:  # none of the remaining pairs is in the merge table
            break # no more bigrams are eligible to be merged
        first, second = bigram

        # we will now replace all occurrences of (first, second) in the list of current
        # words into one merged token first_second, in the output list new_word
        new_word = []
        i = 0
        while i < len(word):  # merge the pair (it may occur more than once)

            # find the next occurrence of first in the sequence of current words
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:  # first does not occur again
                new_word.extend(word[i:])
                break

            # if this occurrence is also followed by second, then merge them into one
            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1

        # all occurrences of (first, second) have been merged to first_second
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # concat all words into a string, and use ' ' as the separator. Note that
    # by now all characters have been byte encoded, guaranteeing that ' ' is
    # not used in the actual data and is a 'special' delimiter character
    word = ' '.join(word)

    # cache the result and return
    self.cache[token] = word
    return word
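To see the method in action without downloading GPT-2's files, the same merge loop can be run against a small merge table. The `toy_merges` list below is made up for illustration (it is not taken from vocab.bpe), `bpe_word` is a hypothetical standalone version of Encoder.bpe without the cache, and its get_pairs is an equivalent one-line rewrite using zip.

```python
def get_pairs(word):
    # all adjacent symbol pairs in the tuple `word`
    return set(zip(word, word[1:]))

def bpe_word(token, bpe_ranks):
    word = tuple(token)
    pairs = get_pairs(word)
    if not pairs:
        return token
    while True:
        # lowest rank = earliest learned merge = highest priority
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if bigram not in bpe_ranks:
            break  # no remaining pair is in the merge table
        first, second = bigram
        # merge every non-overlapping occurrence of (first, second), left to right
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    return " ".join(word)

# hypothetical merge table, listed in priority order (rank = list index)
toy_merges = [("l", "o"), ("lo", "w"), ("e", "r"), ("Ġ", "low")]
bpe_ranks = {pair: rank for rank, pair in enumerate(toy_merges)}

print(bpe_word("Ġlower", bpe_ranks))  # Ġlow er
```

The merges fire in rank order (l+o, then lo+w, then e+r, then Ġ+low); afterwards the only pair left, ('Ġlow', 'er'), is not in the table, so the loop stops.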

Below, the bpe method is explained chunk by chunk:

"""
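An Encoder instance keeps a cache (a dict). Each time bpe is called on a token, it first checks whether the token is already cached; if so, it returns the cached result immediately.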
"""
# the cache speeds up BPE for repeated tokens
if token in self.cache:  
    return self.cache[token]
"""
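Split the input token into characters. At this point the token is a single word produced by the regex pre-tokenization and already byte-encoded (e.g. 'Ġthere'); tuple() breaks it into a tuple of its individual characters.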
"""
word = tuple(token) # individual characters that make up the token, in a tuple
"""
get_pairs takes this character tuple and returns all pairs of adjacent characters (bigrams).
"""
pairs = get_pairs(word) # get all bigrams
"""
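The input word is the ordered tuple of the token's characters. Starting from the first character, every two adjacent characters form a pair; since the result is a set, a pair that occurs several times is stored only once.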
"""
def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
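A quick check of get_pairs (copied from above) on a word with a repeated pair. "banana" contains ('a', 'n') and ('n', 'a') twice each, yet each appears once in the set, which is why the merge step has to rescan the whole word to find every occurrence. The sorted() call is only there to make the printed order deterministic.

```python
def get_pairs(word):
    pairs = set()
    prev_char = word[0]  # word must be non-empty; the regex never produces empty tokens
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

print(sorted(get_pairs(tuple("banana"))))
# [('a', 'n'), ('b', 'a'), ('n', 'a')]

# after some merges, the "characters" can be multi-character symbols
print(sorted(get_pairs(("Ġ", "low", "er"))))
# [('low', 'er'), ('Ġ', 'low')]
```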
"""
If the token produced no pairs (it consists of a single character), return it unchanged.
"""
if not pairs:
    return token
"""
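Find the pair with the highest merge priority. bpe_ranks gives each known pair a rank, and min() picks the pair with the smallest rank, i.e. the merge that was learned earliest (during training, roughly the most frequent pair). Pairs missing from bpe_ranks get the rank float('inf'), so they can only be chosen when no known pair remains.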
"""
# find the next lowest rank bigram that can be merged
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))  # prefer the pair with the lowest rank, i.e. the merge learned earliest
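The `min(..., key=...)` idiom can be tried on its own. With a hypothetical rank table, the known pair with the smallest rank wins; unknown pairs get `float('inf')`, so if every pair is unknown, min() still returns one of them and the following `bigram not in self.bpe_ranks` check is what stops the loop. The helper `best_pair` is hypothetical.

```python
# hypothetical ranks: smaller number = merge learned earlier = higher priority
bpe_ranks = {("e", "r"): 2, ("l", "o"): 0, ("o", "w"): 5}

def best_pair(pairs):
    # unknown pairs sort last because their key is infinity
    return min(pairs, key=lambda pair: bpe_ranks.get(pair, float("inf")))

print(best_pair({("o", "w"), ("e", "r"), ("w", "e")}))  # ('e', 'r')

unknown = best_pair({("x", "y"), ("y", "z")})
print(unknown in bpe_ranks)  # False, so the merge loop would break here
```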
"""
bpe_ranks is the dict from each pair to its rank, built in __init__ from bpe_merges, the ordered list of merges read from a precomputed file:
"""
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
"""
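Each line of the file after the first (header) line is one merge pair, and its position in the file is its rank: the earlier the line, the higher the merge priority. The file stores the merge order, not the raw frequencies.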
"""
vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe'
get_file(vocab_local_file, vocab_remote_file)
with open(vocab_local_file, 'r', encoding="utf-8") as f:
    bpe_data = f.read()
# light postprocessing: strip the version on first line and the last line is a blank
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
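To see what the parsing line does without downloading anything, it can be applied to a small inline string in the same layout: a header line, one space-separated merge per line, and a trailing newline. The sample content below is made up for illustration and is not copied from GPT-2's vocab.bpe.

```python
# hypothetical file content in the same layout as vocab.bpe
bpe_data = "#version: 0.2\nĠ t\nĠ a\nh e\nĠt he\n"

# [1:-1] drops the header line and the empty string after the final newline
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
# rank = position in the file
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

print(bpe_merges)  # [('Ġ', 't'), ('Ġ', 'a'), ('h', 'e'), ('Ġt', 'he')]
print(bpe_ranks[("h", "e")])  # 2
```

Note that a later merge such as ('Ġt', 'he') can only fire after the earlier merges have produced both of its halves.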
"""
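If even the best pair is not in bpe_ranks, no remaining pair can be merged and the loop ends. Otherwise, first and second are the two symbols of the pair (after earlier merges, each symbol may already span several characters).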
"""
if bigram not in self.bpe_ranks:  # none of the remaining pairs is in the merge table
    break # no more bigrams are eligible to be merged
first, second = bigram
"""
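This part builds new_word by scanning word from left to right: every occurrence of (first, second) becomes the merged symbol first+second, and all other symbols are copied unchanged.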
"""
# we will now replace all occurrences of (first, second) in the list of current
# words into one merged token first_second, in the output list new_word
new_word = []
i = 0
while i < len(word):  # merge the pair (it may occur more than once)

    # find the next occurrence of first in the sequence of current words
    try:
        j = word.index(first, i)
        new_word.extend(word[i:j])
        i = j
    except ValueError:  # first does not occur again
        new_word.extend(word[i:])
        break

    # if this occurrence is also followed by second, then merge them into one
    if word[i] == first and i < len(word)-1 and word[i+1] == second:
        new_word.append(first+second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1
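The scan-and-merge loop replaces occurrences left to right without overlap. A standalone copy (the helper name `merge_pair` is hypothetical) shows two consequences: every occurrence of the pair is merged in a single pass, and in a run like ('a', 'a', 'a') only the first two symbols are merged.

```python
def merge_pair(word, first, second):
    """Replace each non-overlapping occurrence of (first, second) in the tuple `word`."""
    new_word = []
    i = 0
    while i < len(word):
        try:
            j = word.index(first, i)  # next position of `first` at or after i
        except ValueError:  # `first` does not occur again
            new_word.extend(word[i:])
            break
        new_word.extend(word[i:j])  # copy the symbols before it unchanged
        i = j
        if i < len(word) - 1 and word[i + 1] == second:
            new_word.append(first + second)
            i += 2
        else:
            new_word.append(word[i])
            i += 1
    return tuple(new_word)

print(merge_pair(("b", "a", "n", "a", "n", "a"), "a", "n"))  # ('b', 'an', 'an', 'a')
print(merge_pair(("a", "a", "a"), "a", "a"))  # ('aa', 'a')
```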
"""
If the whole word has been merged into a single symbol, stop; otherwise recompute the pairs and run the loop again.
"""
# all occurrences of (first, second) have been merged to first_second
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
    break
else:
    pairs = get_pairs(word)
"""
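Finally, join the symbols into one string separated by spaces and store it in the cache. A space is safe as the separator because, after byte encoding, a real space in the text appears as 'Ġ'.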
"""
word = ' '.join(word)

# cache the result and return
self.cache[token] = word

Using GPT-2's BPE code as the example, these notes mainly record a reading of the bpe method in the Encoder class.
