This post summarizes a 2016 paper on subword segmentation for NLP (paper link: Subword; reference implementation: subword-nmt). Many later models (BERT among them) apply this kind of subword tokenization, and compared with word-level and character-level segmentation it works quite well.
At its core this is byte pair encoding (BPE), a very simple compression algorithm. Concretely it works in a few steps: split every word into its characters (with an end-of-word marker attached to the last one), count every pair of adjacent symbols over the whole corpus, merge the most frequent pair into a single new symbol, and repeat this merge a fixed number of times.
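Before looking at the repository, here is a minimal, self-contained sketch of that loop in plain Python. It follows the pseudocode in the paper; the helper name learn_bpe_toy and the toy corpus are mine, not part of subword-nmt.

import re
from collections import Counter

def learn_bpe_toy(corpus_words, num_merges):
    """Toy BPE learner: repeatedly merge the most frequent adjacent symbol pair."""
    # Each word becomes a space-separated sequence of symbols plus an end-of-word marker.
    vocab = {' '.join(word) + ' </w>': freq for word, freq in Counter(corpus_words).items()}
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for word, freq in vocab.items():
            symbols = word.split()
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        merges.append(best)
        # Replace every occurrence of the pair with the merged symbol.
        pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(best)) + r'(?!\S)')
        vocab = {pattern.sub(''.join(best), word): freq for word, freq in vocab.items()}
    return merges

print(learn_bpe_toy(['low', 'low', 'lower', 'newest', 'newest', 'widest'], 5))

The real learn_bpe.py implements exactly this idea, but keeps an inverted index so that each merge only has to touch the words that actually contain the chosen pair.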
The method itself is that simple, but the concrete implementation is worth a look and is quite interesting. The code referenced here is subword-nmt, which revolves around two files: learn_bpe.py, which learns the merge operations, and apply_bpe.py, which applies them to new text. Let's go through them in turn.
learn_bpe.py first builds a word-frequency vocabulary from the input and attaches the end-of-word marker '</w>' to the last character of every word:

vocab = get_vocabulary(infile, is_dict)
vocab = dict([(tuple(x[:-1]) + (x[-1] + '</w>',), y) for (x, y) in vocab.items()])
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
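For intuition, the resulting vocab maps each word, represented as a tuple of symbols with '</w>' glued onto the last character, to its corpus frequency. A hypothetical example (not real output):

vocab = {
    ('l', 'o', 'w</w>'): 5,                 # "low" occurred 5 times
    ('l', 'o', 'w', 'e', 'r</w>'): 2,       # "lower" occurred 2 times
    ('n', 'e', 'w', 'e', 's', 't</w>'): 6,  # "newest" occurred 6 times
}

Next, get_pair_statistics builds the initial pair counts and an inverted index over this vocabulary: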
def get_pair_statistics(vocab):
    """Count frequency of all symbol pairs, and create index"""

    # data structure of pair frequencies
    stats = defaultdict(int)

    # index from pairs to words
    indices = defaultdict(lambda: defaultdict(int))

    for i, (word, freq) in enumerate(vocab):
        prev_char = word[0]
        for char in word[1:]:
            stats[prev_char, char] += freq
            indices[prev_char, char][i] += 1
            prev_char = char

    return stats, indices
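As a quick illustration with toy numbers (not real output), feeding a two-word vocabulary into get_pair_statistics would give:

sorted_vocab = [(('l', 'o', 'w</w>'), 5), (('l', 'o', 'w', 'e', 'r</w>'), 2)]
stats, indices = get_pair_statistics(sorted_vocab)
# stats counts every adjacent pair, weighted by word frequency:
#   stats[('l', 'o')] == 7        # 5 from "low" + 2 from "lower"
#   stats[('o', 'w</w>')] == 5
#   stats[('o', 'w')] == 2
# indices records, for each pair, which words (by position in sorted_vocab)
# contain it and how many times:
#   indices[('l', 'o')] == {0: 1, 1: 1}

The main loop of learn_bpe then repeatedly picks the most frequent pair, writes it to the codes file, and merges it inside the vocabulary: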
for i in range(num_symbols):
    if stats:
        most_frequent = max(stats, key=lambda x: (stats[x], x))  # most frequent 2-gram so far

    # we probably missed the best pair because of pruning; go back to full statistics
    if not stats or (i and stats[most_frequent] < threshold):
        prune_stats(stats, big_stats, threshold)
        stats = copy.deepcopy(big_stats)
        most_frequent = max(stats, key=lambda x: (stats[x], x))
        # threshold is inspired by Zipfian assumption, but should only affect speed
        threshold = stats[most_frequent] * i/(i+10000.0)
        prune_stats(stats, big_stats, threshold)

    if stats[most_frequent] < min_frequency:
        sys.stderr.write('no pair has frequency >= {0}. Stopping\n'.format(min_frequency))
        break

    if verbose:
        sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))
    outfile.write('{0} {1}\n'.format(*most_frequent))               # record the most frequent 2-gram in the codes file
    changes = replace_pair(most_frequent, sorted_vocab, indices)    # merge and replace that pair everywhere in the vocabulary
    update_pair_statistics(most_frequent, changes, stats, indices)  # update the counts of pairs adjacent to the new merged symbol
    stats[most_frequent] = 0
This produces the 2-gram merge table, the BPE_codes: each line of the output file holds the two symbols of one merge operation, and the line order records the order in which the merges were learned.
learn_bpe.py thus gives us the BPE_codes learned from the data set. Before feeding data to the model we need to encode the input with these codes; in plain terms, we segment the input text according to the merge rules stored in BPE_codes. This is the job of apply_bpe.py.
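apply_bpe.py first loads the codes file into the two dictionaries that the functions below rely on. Here is a simplified sketch of that step (load_codes is a hypothetical helper, not the library's API, and it ignores the '#version:' header line that the real loader also parses):

def load_codes(path):
    """Map each learned pair to its merge priority, plus a reverse lookup table."""
    with open(path, encoding='utf-8') as f:
        pairs = [tuple(line.strip('\r\n ').split(' ')) for line in f if line.strip()]
    # Lower index = learned earlier = applied first during encoding.
    bpe_codes = {pair: i for i, pair in enumerate(pairs)}
    # Merged symbol -> the pair it was built from (used when an OOV subword must be re-split).
    bpe_codes_reverse = {pair[0] + pair[1]: pair for pair in pairs}
    return bpe_codes, bpe_codes_reverse

With those two dictionaries in hand, segment_tokens walks over a tokenized sentence and encodes each word: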
def segment_tokens(self, tokens):
    """segment a sequence of tokens with BPE encoding"""
    output = []
    for word in tokens:
        # eliminate double spaces
        if not word:
            continue
        new_word = [out for segment in self._isolate_glossaries(word)
                    for out in encode(          # encode is shown below
                        segment,                # the piece of the word to encode
                        self.bpe_codes,         # codes from learn_bpe, as a dict {pair tuple: merge priority}
                        self.bpe_codes_reverse, # {pair[0] + pair[1]: pair}
                        self.vocab,
                        self.separator,
                        self.version,
                        self.cache,
                        self.glossaries)]       # user-supplied glossary (e.g. city or country names);
                                                # these entries are never split and are isolated from the rest of the word

        for item in new_word[:-1]:
            output.append(item + self.separator)
        output.append(new_word[-1])

    return output
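As a usage example (the file name and the printed subwords are illustrative; actual output depends on the learned codes), a BPE object built from a codes file can be called like this:

import codecs
from apply_bpe import BPE   # or: from subword_nmt.apply_bpe import BPE, if installed as a package

with codecs.open('codes.bpe', encoding='utf-8') as codes:
    bpe = BPE(codes)

print(bpe.segment_tokens(['lower', 'newest']))
# e.g. ['lo@@', 'wer', 'new@@', 'est']   -- '@@' is the default separator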
Now look at the encode method: it keeps merging the 2-grams inside the sequence according to bpe_codes. The loop stops once the word has been reduced to a single symbol (nothing left to merge) or the best-ranked remaining pair is no longer in bpe_codes.
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries=None):
    """Encode word based on list of BPE merge operations, which are applied consecutively
    """
    if orig in cache:
        return cache[orig]

    # glossary entries are returned unsplit
    if glossaries and re.match('^({})$'.format('|'.join(glossaries)), orig):
        cache[orig] = (orig,)
        return (orig,)

    if version == (0, 1):
        word = tuple(orig) + ('</w>',)
    elif version == (0, 2):  # more consistent handling of word-final segments
        word = tuple(orig[:-1]) + (orig[-1] + '</w>',)
    else:
        raise NotImplementedError

    pairs = get_pairs(word)

    if not pairs:
        return orig
    while True:
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
        if bigram not in bpe_codes:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>', ''),)

    if vocab:
        # vocab is a vocabulary built with a frequency threshold, so low-frequency subwords
        # (which the paper argues may just be noise) have been filtered out.
        # Filtering can make a merged subword OOV; in that case the word is split back into
        # smaller units, which are more likely to be present in the subword vocabulary.
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    cache[orig] = word
    return word
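To make the merge loop concrete, here is a hand trace of encode under a small hypothetical set of codes (toy priorities, not learned from real data):

# hypothetical bpe_codes: {('l', 'o'): 0, ('lo', 'w'): 1, ('e', 'r</w>'): 2}
# encode('lower', ...) with version (0, 2):
#   start                   ('l', 'o', 'w', 'e', 'r</w>')
#   merge ('l', 'o')     -> ('lo', 'w', 'e', 'r</w>')
#   merge ('lo', 'w')    -> ('low', 'e', 'r</w>')
#   merge ('e', 'r</w>') -> ('low', 'er</w>')
#   best remaining pair ('low', 'er</w>') is not in bpe_codes -> stop
#   strip the end-of-word marker -> ('low', 'er')
# segment_tokens then emits ['low@@', 'er'] using the default '@@' separator.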
And with that, the input data has been segmented into subwords.