彻底搞懂BPE(Byte Pair Encode)原理(附代码实现)

Byte Pair Encoding

既然你查到这了,就不解释BPE是干啥的了,直接上原理!

核心思想

迭代合并出现频率高的字符对。

例子

1.准备一个语料库(corpus),并统计这个语料库中每个词语的词频,通过“[词频]词语_”的形式存储,这里的“_”表示词语结尾。

注:“er_”和“er”意思不同,“er_”只能放在结尾,组成“newer_”等,“er”则不表示结尾,可以组成“era”等。

Corpus:
[5] low_
[2] lowest_
[6] newer_
[3] wider_
[2] new_

2.设置token词表的大小,或者循环的次数,作为终止条件。
3.统计每个字符出现的次数,结尾“_”的次数也要统计。这张表就是vocabulary,后续迭代结束之后就会利用这张表进行分词。

字符 频次
_ 18
d 3
e 19
i 3
l 7
n 8
o 7
r 9
s 2
t 2
w 22

4.选择两个连续的字符(序)进行合并,并且合并后有着最高的频次。第一次迭代,选择“r”和“_”,组合成“r_”,总共有 6 + 3 = 9次。然后我们将“r_”和它的频次加入vocabulary,并且减去“r”和“_”的次数。

字符 频次
_ 9
d 3
e 19
i 3
l 7
n 8
o 7
r 0
s 2
t 2
w 22
r_ 9

此时,“r”的频次变成0,说明“r”的出现一定会与“_”关联,也就是说“r”一定是最后一个单词。这个时候可以把“r”从词表里删除,词表由于增加了一个“r_”,减少了一个“r”,所以长度不变。(这里也是为什么大家总说词表一般情况是先增加后减少)

5.接下来合并“e”和“r_”,因为“er_”总共出现了 6 + 3 = 9次是当前频次最高的,同样更新词表。增加了“er_”,减少了“r_”,所以词表长度不变。

字符 频次
_ 9
d 3
e 10
i 3
l 7
n 8
o 7
s 2
t 2
w 22
er_ 9

6.然后合并“ew”,共出现8次,更新后的表(_, d, e, i, l, n, o, s, t, w, er_, ew)。这次“ew”没有消除所有的“e”或“w”,也就是说“e”或“w”除了出现在“ew”中还会出现在别的地方,比如“wider”中的“e”和“w”就是分开的。所以词表的长度增加了1。

7.接着就是“new”,总共8次,更新后的表(_, d, e, i, l, o, s, t, w, er_, new)。此时的“new”消除了所有的“n”和“ew”,也就是“n”和“ew”只会出现在“new”里面。这时词表增加了一个“new”但消除了两个,所以词表的长度减少了1。

8.假设我设置循环四次后终止,那么此时的词表就是(_, d, e, i, l, o, s, t, w, er_, new)。

9.根据上述描述,也就可以发现,词表长度的变化总共有三种,+1、-1、不变。

代码

分为两个主要函数,一个专门统计vocabulary,另一个负责合并字符串。

统计词频很暴力,就是遍历vocabulary里每一个元素,利用像两个元素的滑动窗口,挨个组合并且记录频次。找到频次最高的之后进行合并并且输出新的vocabulary。

接下来就是代码,一行行看一定能看懂。

统计词频

import re, collections

text = "The aims for this subject is for students to develop an understanding of the main algorithms used in naturallanguage processing, for use in a diverse range of applications including text classification, machine translation, and question answering. Topics to be covered include part-of-speech tagging, n-gram language modelling, syntactic parsing and deep learning. The programming language used is Python, see for more information on its use in the workshops, assignments and installation at home."
# text = 'low '*5 +'lower '*2+'newest '*6 +'widest '*3

'''
先统计词频
'''
def get_vocab(text):
    
    # 初始化为 0
    vocab = collections.defaultdict(int)
    # 去头去尾再根据空格split
    for word in text.strip().split():
        #note: we use the special token  (instead of underscore in the lecture) to denote the end of a word
        # 给list中每个元素增加空格,并在最后增加结束符号,同时统计单词出现次数
        vocab[' '.join(list(word)) + ' '] += 1
    return vocab
print(get_vocab(text))

彻底搞懂BPE(Byte Pair Encode)原理(附代码实现)_第1张图片
统计相邻字符对的频率

"""
这个函数遍历词汇表中的所有单词,并计算彼此相邻的一对标记。

EXAMPLE:
    word = 'T h e <\w>'
    这个单词可以两两组合成: [('T', 'h'), ('h', 'e'), ('e', '<\w>')]
    
输入:
    vocab: Dict[str, int]  # vocab统计了词语出现的词频
    
输出:
    pairs: Dict[Tuple[str, str], int] # 字母对,pairs统计了单词对出现的频率
"""
def get_stats(vocab):
    pairs = collections.defaultdict(int)
    
    for word,freq in vocab.items():
        
        # 遍历每一个word里面的symbol,去凑所有的相邻两个内容
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[(symbols[i],symbols[i+1])] += freq

    return pairs

开始合并高频字符对

"""
EXAMPLE:
    word = 'T h e <\w>'
    pair = ('e', '<\w>')
    word_after_merge = 'T h e<\w>'
    
输入:
    pair: Tuple[str, str] # 需要合并的字符对
    v_in: Dict[str, int]  # 合并前的vocab
    
输出:
    v_out: Dict[str, int] # 合并后的vocab
    
注意:
    当合并word 'Th e<\w>'中的字符对 ('h', 'e')时,'Th'和'e<\w>'字符对不能被合并。
"""
def merge_vocab(pair, v_in):
    v_out = {}
    # 把pair拆开,然后用空格合并起来,然后用\把空格转义
    bigram = re.escape(' '.join(pair))
    # 自定义一个正则规则, (?排除在外
    p = re.compile(r'(? + bigram + r'(?!\S)')
    
    for v in v_in:
        # 遍历当前的vocabulary,找到匹配正则的v时,才用合并的pair去替换变成新的pair new,如果没有匹配上,那就保持原来的。
        # 比如pair当前是'h'和'e',然后遍历vocabulary,找到符合前后都没有东西只有'h\ e'的时候就把他们并在一起变成'he'
        new = p.sub(''.join(pair),v)
        # 然后新的合并的数量就是当前vocabulary里面pair对应的数量
        v_out[new] = v_in[v]
    return v_out

def get_tokens(vocab):
    tokens = collections.defaultdict(int)
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
    return tokens


vocab = get_vocab(text)
print("Vocab =", vocab)
print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')

#about 100 merges we start to see common words
num_merges = 100
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    
    # vocabulary里面pair出现次数最高的作为最先合并的pair
    best = max(pairs, key=pairs.get)
    
    # 先给他合并了再说,当然这里不操作也没什么,到merge_vocab里面都一样
    new_token = ''.join(best)
    vocab = merge_vocab(best, vocab)
    print('Iter: {}'.format(i))
    print('Best pair: {}'.format(best))
    # add new token to the vocab
    tokens[new_token] = pairs[best]
    # deduct frequency for tokens have been merged
    tokens[best[0]] -= pairs[best]
    tokens[best[1]] -= pairs[best]
    print('Tokens: {}'.format(tokens))
    print('Number of tokens: {}'.format(len(tokens)))
    print('==========')
    print('vocab, ', vocab)

至此我们就讲完了基本的BPE代码,接下来就是我们怎么去实践,从一个最基本的数据预处理,到最后实现一句话的简单分词。

以下是代码:

def get_tokens_from_vocab(vocab):
    tokens_frequencies = collections.defaultdict(int)
    vocab_tokenization = {}
    for word, freq in vocab.items():
        # 看vocabulary里面的token频率,相当于上面的code中的tokens去除freq为0的
        word_tokens = word.split()
        for token in word_tokens:
            tokens_frequencies[token] += freq
        # vocab和其对应的tokens
        vocab_tokenization[''.join(word_tokens)] = word_tokens
    return tokens_frequencies, vocab_tokenization

def measure_token_length(token):
    
    # 如果token最后四个元素是 < / w >
    if token[-4:] == '':
        # 那就返回除了最后四个之外的长度再加上1(结尾)
        return len(token[:-4]) + 1
    else:
        # 如果这个token里面没有结尾就直接返回当前长度
        return len(token)
    
# 如果vocabulary里面找不到要拆分的词,就根据已经有的token现拆
def tokenize_word(string, sorted_tokens, unknown_token=''):
    
    # base case,没词进来了,那拆的结果就是空的
    if string == '':
        return []
    # 已有的sorted tokens没有了,那就真的没这个词了
    if sorted_tokens == []:
        return [unknown_token] * len(string)

    # 记录拆分结果
    string_tokens = []
    
    # iterate over all tokens to find match
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        
        # 自定义一个正则,然后要把token里面包含句号的变成[.]
        token_reg = re.escape(token.replace('.', '[.]'))
        
        # 在当前string里面遍历,找到每一个match token的开始和结束位置,比如string=good,然后token是o,输出[(2,2),(3,3)]?
        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
        # if no match found in the string, go to next token
        if len(matched_positions) == 0:
            continue
        # 因为要拆分这个词,匹配上的token把这个word拆开了,那就要拿到除了match部分之外的substring,所以这里要拿match的start
        substring_end_positions = [matched_position[0] for matched_position in matched_positions]
        substring_start_position = 0
        
        
        # 如果有匹配成功的话,就会进入这个循环
        for substring_end_position in substring_end_positions:
            # slice for sub-word
            substring = string[substring_start_position:substring_end_position]
            # tokenize this sub-word with tokens remaining 接着用substring匹配剩余的sorted token,因为刚就匹配了一个
            string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
            # 先把sorted token里面匹配上的记下来
            string_tokens += [token]
            substring_start_position = substring_end_position + len(token)
        # tokenize the remaining string 去除前头的substring,去除已经匹配上的,后面还剩下substring_start_pos到结束的一段substring没看
        remaining_substring = string[substring_start_position:]
        # 接着匹配
        string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
        break
    else:
        # return list of unknown token if no match is found for the string
        string_tokens = [unknown_token] * len(string)
        
    return string_tokens

"""
该函数生成一个所有标记的列表,按其长度(第一键)和频率(第二键)排序。

EXAMPLE:
    token frequency dictionary before sorting: {'natural': 3, 'language':2, 'processing': 4, 'lecture': 4}
    sorted tokens: ['processing', 'language', 'lecture', 'natural']
    
INPUT:
    token_frequencies: Dict[str, int] # Counter for token frequency
    
OUTPUT:
    sorted_token: List[str] # Tokens sorted by length and frequency

"""
def sort_tokens(tokens_frequencies):
    # 对 token_frequencies里面的东西,先进行长度排序,再进行频次,sorted是从低到高所以要reverse
    sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item:(measure_token_length(item[0]),item[1]), reverse=True)
    
    # 然后只要tokens不要频次
    sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]

    return sorted_tokens

#display the vocab
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)

#sort tokens by length and frequency
sorted_tokens = sort_tokens(tokens_frequencies)
print("Tokens =", sorted_tokens, "\n")

#print("vocab tokenization: ", vocab_tokenization)

sentence_1 = 'I like natural language processing!'
sentence_2 = 'I like natural languaaage processing!'
sentence_list = [sentence_1, sentence_2]

for sentence in sentence_list:
    
    print('==========')
    print("Sentence =", sentence)
    
    for word in sentence.split():
        word = word + ""

        print('Tokenizing word: {}...'.format(word))
        if word in vocab_tokenization:
            print(vocab_tokenization[word])
        else:
            print(tokenize_word(string=word, sorted_tokens=sorted_tokens, unknown_token=''))

注:如果害没有看懂可以看这篇(也有知乎翻译的)或者这篇中间提到BPE的地方(我借助了这两篇的表现方式来讲解)。详细实现代码可以参考这篇,每行代码都有注释。

优点

  1. Data-informed tokenisation (理解下来可能是这样分出来的词表具有一定数据启示,因为同时根据字符出现的频率来进行分割)
  2. 不同语种通用
  3. 对于未知单词也适用

缺点

就是会产生一些不完整的单词(subword)。

其他

1.实际操作中,BPE会运行成千上万次merge,产生很大的vocabulary
2.经常出现的词会完整的呈现在vocabulary,因为他的词频很高
3.相反,出现少的词就会以subword的形式呈现
4.最糟的情况就是测试集中出现了从来没见过的词(也有可能是Missspelling),那它就会被分成一个个单词。
5.想了解NLP基础理论也可以看看我的博客

你可能感兴趣的:(ai,自然语言处理)