文章来源 | 恒源云社区
原文地址 | BPE 算法详解
原文作者 | Mathor
Byte Pair Encoding
在NLP模型中,输入通常是一个句子,例如"I went to New York last week."
,一句话中包含很多单词(token)。传统的做法是将这些单词以空格进行分隔,例如['i', 'went', 'to', 'New', 'York', 'last', 'week']
。然而这种做法存在很多问题,例如模型无法通过old, older, oldest
之间的关系学到smart, smarter, smartest
之间的关系。如果我们能使用将一个token分成多个subtokens,上面的问题就能很好的解决。本文将详述目前比较常用的subtokens算法——BPE(Byte-Pair Encoding)
现在性能比较好一些的NLP模型,例如GPT、BERT、RoBERTa等,在数据预处理的时候都会有WordPiece的过程,其主要的实现方式就是BPE(Byte-Pair Encoding)。具体来说,例如['loved', 'loving', 'loves']
这三个单词。其实本身的语义都是"爱"的意思,但是如果我们以词为单位,那它们就算不一样的词,在英语中不同后缀的词非常的多,就会使得词表变的很大,训练速度变慢,训练的效果也不是太好。BPE算法通过训练,能够把上面的3个单词拆分成["lov","ed","ing","es"]
几部分,这样可以把词的本身的意思和时态分开,有效的减少了词表的数量。算法流程如下:
- 设定最大subwords个数
- 将所有单词拆分为单个字符,并在最后添加一个停止符
,同时标记出该单词出现的次数。例如,
"low"
这个单词出现了5次,那么它将会被处理为{'l o w ': 5}
- 统计每一个连续字节对的出现频率,选择最高频者合并成新的subword
- 重复第3步直到达到第1步设定的subwords词表大小或下一个最高频的字节对出现频率为1
例如
{'l o w ': 5, 'l o w e r ': 2, 'n e w e s t ': 6, 'w i d e s t ': 3}
出现最频繁的字节对是** e
和s
**,共出现了6+3=9次,因此将它们合并
{'l o w ': 5, 'l o w e r ': 2, 'n e w es t ': 6, 'w i d es t ': 3}
出现最频繁的字节对是** es
和t
**,共出现了6+3=9次,因此将它们合并
{'l o w ': 5, 'l o w e r ': 2, 'n e w est ': 6, 'w i d est ': 3}
出现最频繁的字节对是** est
和 **,共出现了6+3=9次,因此将它们合并
{'l o w ': 5, 'l o w e r ': 2, 'n e w est': 6, 'w i d est': 3}
出现最频繁的字节对是** l
和o
**,共出现了5+2=7次,因此将它们合并
{'lo w ': 5, 'lo w e r ': 2, 'n e w est': 6, 'w i d est': 3}
出现最频繁的字节对是** lo
和w
**,共出现了5+2=7次,因此将它们合并
{'low ': 5, 'low e r ': 2, 'n e w est': 6, 'w i d est': 3}
…继续迭代直到达到预设的subwords词表大小或下一个最高频的字节对出现频率为1。这样我们就得到了更加合适的词表,这个词表可能会出现一些不是单词的组合,但是其本身有意义的一种形式
停止符的意义在于表示subword是词后缀。举例来说:
st
不加可以出现在词首,如
st ar
;加了表明改字词位于词尾,如
wide st
,二者意义截然不同
BPE实现
import re, collections
def get_vocab(filename):
vocab = collections.defaultdict(int)
with open(filename, 'r', encoding='utf-8') as fhand:
for line in fhand:
words = line.strip().split()
for word in words:
vocab[' '.join(list(word)) + ' '] += 1
return vocab
def get_stats(vocab):
pairs = collections.defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols)-1):
pairs[symbols[i],symbols[i+1]] += freq
return pairs
def merge_vocab(pair, v_in):
v_out = {}
bigram = re.escape(' '.join(pair))
p = re.compile(r'(?': 5, 'l o w e r ': 2, 'n e w e s t ': 6, 'w i d e s t ': 3}
# Get free book from Gutenberg
# wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt
# vocab = get_vocab('pg16457.txt')
print('==========')
print('Tokens Before BPE')
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')
num_merges = 5
for i in range(num_merges):
pairs = get_stats(vocab)
if not pairs:
break
best = max(pairs, key=pairs.get)
vocab = merge_vocab(best, vocab)
print('Iter: {}'.format(i))
print('Best pair: {}'.format(best))
tokens = get_tokens(vocab)
print('Tokens: {}'.format(tokens))
print('Number of tokens: {}'.format(len(tokens)))
print('==========')
输出如下
==========
Tokens Before BPE
Tokens: defaultdict(, {'l': 7, 'o': 7, 'w': 16, '': 16, 'e': 17, 'r': 2, 'n': 6, 's': 9, 't': 9, 'i': 3, 'd': 3})
Number of tokens: 11
==========
Iter: 0
Best pair: ('e', 's')
Tokens: defaultdict(, {'l': 7, 'o': 7, 'w': 16, '': 16, 'e': 8, 'r': 2, 'n': 6, 'es': 9, 't': 9, 'i': 3, 'd': 3})
Number of tokens: 11
==========
Iter: 1
Best pair: ('es', 't')
Tokens: defaultdict(, {'l': 7, 'o': 7, 'w': 16, '': 16, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})
Number of tokens: 10
==========
Iter: 2
Best pair: ('est', '')
Tokens: defaultdict(, {'l': 7, 'o': 7, 'w': 16, '': 7, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})
Number of tokens: 10
==========
Iter: 3
Best pair: ('l', 'o')
Tokens: defaultdict(, {'lo': 7, 'w': 16, '': 7, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})
Number of tokens: 9
==========
Iter: 4
Best pair: ('lo', 'w')
Tokens: defaultdict(, {'low': 7, '': 7, 'e': 8, 'r': 2, 'n': 6, 'w': 9, 'est': 9, 'i': 3, 'd': 3})
Number of tokens: 9
==========
编码和解码
编码
在之前的算法中,我们已经得到了subword的词表,对该词表按照字符个数由多到少排序。编码时,对于每个单词,遍历排好序的子词词表寻找是否有token是当前单词的子字符串,如果有,则该token是表示单词的tokens之一
我们从最长的token迭代到最短的token,尝试将每个单词中的子字符串替换为token。 最终,我们将迭代所有tokens,并将所有子字符串替换为tokens。 如果仍然有子字符串没被替换但所有token都已迭代完毕,则将剩余的子词替换为特殊token,如
例如
# 给定单词序列
["the", "highest", "mountain"]
# 排好序的subword表
# 长度 6 5 4 4 4 4 2
["errrr", "tain", "moun", "est", "high", "the", "a"]
# 迭代结果
"the" -> ["the"]
"highest" -> ["high", "est"]
"mountain" -> ["moun", "tain"]
解码
将所有的tokens拼在一起即可,例如
# 编码序列
["the", "high", "est", "moun", "tain"]
# 解码序列
"the highest mountain"
编码和解码实现
import re, collections
def get_vocab(filename):
vocab = collections.defaultdict(int)
with open(filename, 'r', encoding='utf-8') as fhand:
for line in fhand:
words = line.strip().split()
for word in words:
vocab[' '.join(list(word)) + ' '] += 1
return vocab
def get_stats(vocab):
pairs = collections.defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols)-1):
pairs[symbols[i],symbols[i+1]] += freq
return pairs
def merge_vocab(pair, v_in):
v_out = {}
bigram = re.escape(' '.join(pair))
p = re.compile(r'(?':
return len(token[:-4]) + 1
else:
return len(token)
def tokenize_word(string, sorted_tokens, unknown_token=''):
if string == '':
return []
if sorted_tokens == []:
return [unknown_token]
string_tokens = []
for i in range(len(sorted_tokens)):
token = sorted_tokens[i]
token_reg = re.escape(token.replace('.', '[.]'))
matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
if len(matched_positions) == 0:
continue
substring_end_positions = [matched_position[0] for matched_position in matched_positions]
substring_start_position = 0
for substring_end_position in substring_end_positions:
substring = string[substring_start_position:substring_end_position]
string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
string_tokens += [token]
substring_start_position = substring_end_position + len(token)
remaining_substring = string[substring_start_position:]
string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
break
return string_tokens
# vocab = {'l o w ': 5, 'l o w e r ': 2, 'n e w e s t ': 6, 'w i d e s t ': 3}
vocab = get_vocab('pg16457.txt')
print('==========')
print('Tokens Before BPE')
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
print('==========')
num_merges = 10000
for i in range(num_merges):
pairs = get_stats(vocab)
if not pairs:
break
best = max(pairs, key=pairs.get)
vocab = merge_vocab(best, vocab)
print('Iter: {}'.format(i))
print('Best pair: {}'.format(best))
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
print('==========')
# Let's check how tokenization will be for a known word
word_given_known = 'mountains'
word_given_unknown = 'Ilikeeatingapples!'
sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item: (measure_token_length(item[0]), item[1]), reverse=True)
sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]
print(sorted_tokens)
word_given = word_given_known
print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
print('Tokenization of the known word:')
print(vocab_tokenization[word_given])
print('Tokenization treating the known word as unknown:')
print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token=''))
else:
print('Tokenizating of the unknown word:')
print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token=''))
word_given = word_given_unknown
print('Tokenizing word: {}...'.format(word_given))
if word_given in vocab_tokenization:
print('Tokenization of the known word:')
print(vocab_tokenization[word_given])
print('Tokenization treating the known word as unknown:')
print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token=''))
else:
print('Tokenizating of the unknown word:')
print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token=''))
输出如下
Tokenizing word: mountains...
Tokenization of the known word:
['mountains']
Tokenization treating the known word as unknown:
['mountains']
Tokenizing word: Ilikeeatingapples!...
Tokenizating of the unknown word:
['I', 'like', 'ea', 'ting', 'app', 'l', 'es!']