PPL(Perplexity) 是用在自然语言处理领域(NLP)中,衡量语言模型好坏的指标。它主要是根据每个词来估计一句话出现的概率,并用句子长度作normalize。
S = W 1 , W 2 , W 3 , . . . , W k S = W_1,W_2,W_3,...,W_k S=W1,W2,W3,...,Wk
P ( S ) = P ( W 1 , W 2 , W 3 , . . . , W k ) = p ( W 1 ) p ( W 2 ∣ W 1 ) . . . p ( W k ∣ W 1 , W 2 , W 3 , . . . , W k − 1 ) P(S) = P(W_1,W_2,W_3,...,W_k)= p(W_1)p(W_2|W_1)...p(W_k|W_1,W_2,W_3,...,W_{k-1}) P(S)=P(W1,W2,W3,...,Wk)=p(W1)p(W2∣W1)...p(Wk∣W1,W2,W3,...,Wk−1)
困惑度与测试集上的句子概率相关,其基本思想是:给测试集的句子赋予较高概率值的语言模型较好,当语言模型训练完之后,测试集中的句子都是正常的句子,那么训练好的模型就是在测试集上的概率越高越好。
下面讲一下其基础公式:
P P ( W ) = P ( w 1 w 2 w 3 . . . w N ) − 1 N = 1 P ( w 1 w 2 w 3 . . . w N ) N PP(W)=P(w_1w_2w_3...w_N)^{-\frac{1}{N}}\\ = \sqrt[N]{\frac{1}{P(w_1w_2w_3...w_N)}} PP(W)=P(w1w2w3...wN)−N1=NP(w1w2w3...wN)1
这里补充一下公式的细节:
根号内是句子概率的倒数,所以显然 句子越好(概率大),困惑度越小,也就是模型对句子越不困惑。 这样我们也就理解了这个指标的名字。
开N次根号(N为句子长度)意味着几何平均数(把句子概率拆成字符概率的连乘)
需要平均的原因是,因为每个字符的概率必然小于1,所以越长的句子的概率在连乘的情况下必然越小,所以为了对长短句公平,需要平均一下
是几何平均的原因,是因为其的特点是,如果有其中的一个概率是很小的,那么最终的结果就不可能很大,从而要求好的句子的每个字符都要有基本让人满意的概率 [2]
当然,这是在数学领域内计算困惑度的公式,在实际的代码层面,用的是另一套公式,需要将上述公式进行转换,下面我就详细来介绍一下:
P P ( W ) = 2 H ( W ) = 2 − 1 N log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) PP(W) = 2^{H(W)}=2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)} PP(W)=2H(W)=2−N1log2P(w1,w2,w3,...,wN)
P P ( W ) = 2 − 1 N log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) = ( 2 log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) ) − 1 N = P ( w 1 , w 2 , w 3 , . . . , w N ) − 1 N = 1 P ( w 1 w 2 w 3 . . . w N ) N PP(W)=2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)}\\ =(2^{\log_2P(w_1,w_2,w_3,...,w_N)})^{-\frac{1}{N}}\\ =P(w_1,w_2,w_3,...,w_N)^{-\frac{1}{N}}\\ =\sqrt[N]{\frac{1}{P(w_1w_2w_3...w_N)}} PP(W)=2−N1log2P(w1,w2,w3,...,wN)=(2log2P(w1,w2,w3,...,wN))−N1=P(w1,w2,w3,...,wN)−N1=NP(w1w2w3...wN)1
PP(W)
在本质上就是变成了交叉熵加一个底数的指数函数,所以当我们要求困惑度,就可以直接求交叉熵了。我们来看一下代码具体是怎么实现困惑度的。
probs = np.take(probs, target, axis=1).diagonal()
total += -np.sum(np.log(probs))
count += probs.size
perplexity = np.exp(total / count)
其实核心代码就这四行。
N
以上就是我个人对于困惑度查询资料以及完成代码之后做出的理解。