Perplexity

困惑度(Perplexity):评价语言模型的指标

1.定义

PPL(Perplexity) 是用在自然语言处理领域(NLP)中,衡量语言模型好坏的指标。它主要是根据每个词来估计一句话出现的概率,并用句子长度作normalize。

  • 其本质上就是计算句子的概率,例如对于句子S(词语w的序列):

S = W 1 , W 2 , W 3 , . . . , W k S = W_1,W_2,W_3,...,W_k S=W1,W2W3,...,Wk

  • 它的概率为:

P ( S ) = P ( W 1 , W 2 , W 3 , . . . , W k ) = p ( W 1 ) p ( W 2 ∣ W 1 ) . . . p ( W k ∣ W 1 , W 2 , W 3 , . . . , W k − 1 ) P(S) = P(W_1,W_2,W_3,...,W_k)= p(W_1)p(W_2|W_1)...p(W_k|W_1,W_2,W_3,...,W_{k-1}) P(S)=P(W1,W2W3,...,Wk)=p(W1)p(W2W1)...p(WkW1,W2W3,...,Wk1)

困惑度与测试集上的句子概率相关,其基本思想是:给测试集的句子赋予较高概率值的语言模型较好,当语言模型训练完之后,测试集中的句子都是正常的句子,那么训练好的模型就是在测试集上的概率越高越好

  • 通俗点来讲,假设词库里有10个单词,那么对于一个完全没有训练过的模型,其预测一个特定单词的概率就是1/10,概率是均等分的,这时候我们就能得出其困惑度为10,也就是模式是完全糊涂的,没有任何分辨能力。但是当模型能将一个特定单词预测出1/2的概率时,就代表模型能从10个单词中挑选出2个可能对的单词,这时候模型的困惑度就是2,说明模型有了一定的分辨能力。当然,这么简单的求倒数获取困惑度的前提是概率是均等的,如果概率不均等,那么困惑度和预测的倒数就不是相等关系了。
  • 当然,最好的就是模型能识别出那个正确的单词,给予100%的概率,这时候模型的困惑度就是1,代表模型没有任何困惑,是完全清楚的,可以正确识别单词,也就是能正确识别一个句子。

2.公式

下面讲一下其基础公式:
P P ( W ) = P ( w 1 w 2 w 3 . . . w N ) − 1 N = 1 P ( w 1 w 2 w 3 . . . w N ) N PP(W)=P(w_1w_2w_3...w_N)^{-\frac{1}{N}}\\ = \sqrt[N]{\frac{1}{P(w_1w_2w_3...w_N)}} PP(W)=P(w1w2w3...wN)N1=NP(w1w2w3...wN)1
这里补充一下公式的细节:

  • 根号内是句子概率的倒数,所以显然 句子越好(概率大),困惑度越小,也就是模型对句子越不困惑。 这样我们也就理解了这个指标的名字。

  • 开N次根号(N为句子长度)意味着几何平均数(把句子概率拆成字符概率的连乘)

    • 需要平均的原因是,因为每个字符的概率必然小于1,所以越长的句子的概率在连乘的情况下必然越小,所以为了对长短句公平,需要平均一下

    • 几何平均的原因,是因为其的特点是,如果有其中的一个概率是很小的,那么最终的结果就不可能很大,从而要求好的句子的每个字符都要有基本让人满意的概率 [2]

      • 机器翻译常用指标BLEU也使用了几何平均,还有机器学习常用的F score使用的调和平均数 ,也有类似的效果。

当然,这是在数学领域内计算困惑度的公式,在实际的代码层面,用的是另一套公式,需要将上述公式进行转换,下面我就详细来介绍一下:

  • 在真实的代码计算中,上述的公式很难计算,但是就是有大佬发现,其实上述的公式可以转化为求交叉熵的公式。而背后的原理是,不管是困惑度,还是交叉熵,其本质上都是在计算信息熵,所以都是在计算模型的混乱程度,因此两者在数学意义的转换就有了理论依据,下面看一下公式转换过程:

P P ( W ) = 2 H ( W ) = 2 − 1 N log ⁡ 2 P ( w 1 , w 2 , w 3 , . . . , w N ) PP(W) = 2^{H(W)}=2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)} PP(W)=2H(W)=2N1log2P(w1,w2,w3,...,wN)

P P ( W ) = 2 − 1 N log ⁡ 2 P ( w 1 , w 2 , w 3 , . . . , w N ) = ( 2 log ⁡ 2 P ( w 1 , w 2 , w 3 , . . . , w N ) ) − 1 N = P ( w 1 , w 2 , w 3 , . . . , w N ) − 1 N = 1 P ( w 1 w 2 w 3 . . . w N ) N PP(W)=2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)}\\ =(2^{\log_2P(w_1,w_2,w_3,...,w_N)})^{-\frac{1}{N}}\\ =P(w_1,w_2,w_3,...,w_N)^{-\frac{1}{N}}\\ =\sqrt[N]{\frac{1}{P(w_1w_2w_3...w_N)}} PP(W)=2N1log2P(w1,w2,w3,...,wN)=(2log2P(w1,w2,w3,...,wN))N1=P(w1,w2,w3,...,wN)N1=NP(w1w2w3...wN)1

  • 从上面可以看出,PP(W)在本质上就是变成了交叉熵加一个底数的指数函数,所以当我们要求困惑度,就可以直接求交叉熵了。
  • 这里还有一个细节,这个底数和log是配套的,在公式中间可以直接消掉,所以底数的大小并不重要,这里选了2,换一个我也可以使用e,这并无关系。

3.代码

我们来看一下代码具体是怎么实现困惑度的。

probs = np.take(probs, target, axis=1).diagonal()
total += -np.sum(np.log(probs))
count += probs.size
perplexity = np.exp(total / count)

其实核心代码就这四行。

  • 第一行,先求出 P ( w 1 , w 2 , w 3 , . . . , w N ) P(w_1,w_2,w_3,...,w_N) P(w1,w2,w3,...,wN),也就是求交叉熵。
  • 第二行,对应的代码是 − log ⁡ 2 P ( w 1 , w 2 , w 3 , . . . , w N ) -\log_2P(w_1,w_2,w_3,...,w_N) log2P(w1,w2,w3,...,wN)
  • 第三行,对应的代码是求N
  • 第四行,对应的代码就是 2 − 1 N log ⁡ 2 P ( w 1 , w 2 , w 3 , . . . , w N ) 2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)} 2N1log2P(w1,w2,w3,...,wN)

以上就是我个人对于困惑度查询资料以及完成代码之后做出的理解。

你可能感兴趣的:(自然语言处理,深度学习,语言模型)