n n n 元语法(n-gram)是指文本中连续出现的 n n n 个词元。当 n n n 分别为 1 , 2 , 3 1,2,3 1,2,3 时,n-gram 又叫作 unigram(一元语法)、bigram(二元语法)和 trigram(三元语法)。
n n n 元语法模型是基于 n − 1 n-1 n−1 阶马尔可夫链的一种概率语言模型(即只考虑前 n − 1 n-1 n−1 个词出现的情况下,后一个词出现的概率):
unigram: P ( w 1 , w 2 , ⋯ , w T ) = ∏ i = 1 T P ( w i ) bigram: P ( w 1 , w 2 , ⋯ , w T ) = P ( x 1 ) ∏ i = 1 T − 1 P ( w i + 1 ∣ w i ) trigram: P ( w 1 , w 2 , ⋯ , w T ) = P ( x 1 ) P ( x 2 ∣ x 1 ) ∏ i = 1 T − 2 P ( w i + 2 ∣ w i , w i + 1 ) \begin{aligned} \text{unigram:}\quad&P(w_1,w_2,\cdots,w_T)=\prod_{i=1}^T P(w_i) \\ \text{bigram:}\quad&P(w_1,w_2,\cdots,w_T)=P(x_1)\prod_{i=1}^{T-1} P(w_{i+1}|w_i) \\ \text{trigram:}\quad&P(w_1,w_2,\cdots,w_T)=P(x_1)P(x_2|x_1)\prod_{i=1}^{T-2} P(w_{i+2}|w_{i},w_{i+1}) \\ \end{aligned} unigram:bigram:trigram:P(w1,w2,⋯,wT)=i=1∏TP(wi)P(w1,w2,⋯,wT)=P(x1)i=1∏T−1P(wi+1∣wi)P(w1,w2,⋯,wT)=P(x1)P(x2∣x1)i=1∏T−2P(wi+2∣wi,wi+1)
BLEU(发音与单词 blue 相同) 最早是用于评估机器翻译的结果, 但现在它已经被广泛用于评估许多应用的输出序列的质量。对于预测序列 pred
中的任意 n n n 元语法, BLEU 的评估都是这个 n n n 元语法是否出现在标签序列 label
中。
BLEU 定义如下:
BLEU = exp ( min ( 0 , 1 − len(label) len(pred) ) ) ∏ n = 1 k p n 1 / 2 n \text{BLEU}=\exp\left(\min\left(0,1-\frac{\text{len(label)}}{\text{len(pred)}}\right)\right)\prod_{n=1}^kp_n^{1/2^n} BLEU=exp(min(0,1−len(pred)len(label)))n=1∏kpn1/2n
其中 len(*) \text{len(*)} len(*) 代表序列 ∗ * ∗ 中的词元个数, k k k 用于匹配最长的 n n n 元语法(常取 4 4 4), p n p_n pn 表示 n n n 元语法的精确度。
具体而言,给定 label
: A , B , C , D , E , F A,B,C,D,E,F A,B,C,D,E,F 和 pred
: A , B , B , C , D A,B,B,C,D A,B,B,C,D,取 k = 3 k=3 k=3。
首先看 p 1 p_1 p1 如何计算。我们先将 pred
中的每个 unigram 都统计出来: ( A ) , ( B ) , ( B ) , ( C ) , ( D ) (A),(B),(B),(C),(D) (A),(B),(B),(C),(D),再将 label
中的每个 unigram 都统计出来: ( A ) , ( B ) , ( C ) , ( D ) , ( E ) , ( F ) (A),(B),(C),(D),(E),(F) (A),(B),(C),(D),(E),(F),然后看它们之间有多少匹配的(不可以重复匹配,即必须保持一一对应的关系)。可以看出一共有 4 4 4 个匹配的,而 pred
中一共有 5 5 5 个 unigram,于是 p 1 = 4 / 5 p_1=4/5 p1=4/5。
再来看 p 2 p_2 p2 如何计算。我们先将 pred
中的每个 bigram 都统计出来: ( A , B ) , ( B , B ) , ( B , C ) , ( C , D ) (A,B),(B,B),(B,C),(C,D) (A,B),(B,B),(B,C),(C,D),再将 label
中的每个 bigram 都统计出来: ( A , B ) , ( B , C ) , ( C , D ) , ( D , E ) , ( E , F ) (A,B),(B,C),(C,D),(D,E),(E,F) (A,B),(B,C),(C,D),(D,E),(E,F),然后看它们之间有多少匹配的。可以看出一共有 3 3 3 个匹配的,而 pred
中一共有 4 4 4 个 bigram,于是 p 2 = 3 / 4 p_2=3/4 p2=3/4。
最后看 p 3 p_3 p3 如何计算。我们先将 pred
中的每个 trigram 都统计出来: ( A , B , B ) , ( B , B , C ) , ( B , C , D ) (A,B,B),(B,B,C),(B,C,D) (A,B,B),(B,B,C),(B,C,D),再将 label
中的每个 trigram 都统计出来: ( A , B , C ) , ( B , C , D ) , ( C , D , E ) , ( D , E , F ) (A,B,C),(B,C,D),(C,D,E),(D,E,F) (A,B,C),(B,C,D),(C,D,E),(D,E,F),然后看它们之间有多少匹配的。可以看出只有 1 1 1 个匹配,而 pred
中一共有 3 3 3 个 trigram,于是 p 3 = 1 / 3 p_3=1/3 p3=1/3。
因此此例的 BLEU 分数为
BLEU = exp ( min ( 0 , 1 − 6 / 5 ) ) ⋅ p 1 1 / 2 ⋅ p 2 1 / 4 ⋅ p 3 1 / 8 = e − 0.2 ⋅ ( 4 5 ) 1 / 2 ⋅ ( 3 4 ) 1 / 4 ⋅ ( 1 3 ) 1 / 8 ≈ 0.5940 \begin{aligned} \text{BLEU}&=\exp(\min(0,1-6/5))\cdot p_1^{1/2}\cdot p_2^{1/4}\cdot p_3^{1/8} \\ &=e^{-0.2}\cdot \left(\frac45\right)^{1/2}\cdot \left(\frac34\right)^{1/4}\cdot\left(\frac13\right)^{1/8} \\ &\approx0.5940 \end{aligned} BLEU=exp(min(0,1−6/5))⋅p11/2⋅p21/4⋅p31/8=e−0.2⋅(54)1/2⋅(43)1/4⋅(31)1/8≈0.5940
根据 BLEU 的定义,当预测序列与标签序列完全相同时,BLEU 的值为 1 1 1。另一方面,由于 e x > 0 e^x>0 ex>0 且 p n ≥ 0 p_n\geq0 pn≥0,因此有
BLEU ∈ [ 0 , 1 ] \text{BLEU}\in[0,1] BLEU∈[0,1]
BLEU 的值越接近 1 1 1,则代表预测效果越好;BLEU 的值越接近 0 0 0,则代表预测效果越差。
此外,由于 n n n 元语法越长匹配难度越大, 所以 BLEU 为更长的 n n n 元语法的精确度分配更大的权重(固定 a ∈ ( 0 , 1 ) a\in(0,1) a∈(0,1),则 a 1 / 2 n a^{1/2^n} a1/2n 会随着 n n n 的增加而增加)。而且,由于预测序列越短获得的 p n p_n pn 值越高,所以系数 exp ( ⋅ ) \exp(\cdot) exp(⋅) 这一项用于惩罚较短的预测序列。
import math
from collections import Counter
def bleu(label, pred, k=4):
# 我们假设输入的label和pred都已经进行了分词
score = math.exp(min(0, 1 - len(label) / len(pred)))
for n in range(1, k + 1):
# 使用哈希表用来存放label中所有的n-gram
hashtable = Counter([' '.join(label[i:i + n]) for i in range(len(label) - n + 1)])
# 匹配成功的个数
num_matches = 0
for i in range(len(pred) - n + 1):
ngram = ' '.join(pred[i:i + n])
if ngram in hashtable and hashtable[ngram] > 0:
num_matches += 1
hashtable[ngram] -= 1
score *= math.pow(num_matches / (len(pred) - n + 1), math.pow(0.5, n))
return score
例如:
label = 'A B C D E F'
pred = 'A B B C D'
for i in range(4):
print(bleu(label.split(), pred.split(), k=i + 1))
# 0.7322950476607851
# 0.6814773296495302
# 0.5940339360503315
# 0.0
[1] d2l. Sequence to Sequence Learning