对于交叉熵损失函数的来由有很多资料可以参考,这里就不再赘述。本文主要尝试对交叉熵损失函数的内部运算做深度解析。
Pytorch官网中对交叉熵损失函数的介绍如下:
CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100,reduce=None, reduction=‘mean’, label_smoothing=0.0)
该损失函数计算输入值(input)和目标值(target)之间的交叉熵损失。交叉熵损失函数可用于训练一个 C C C类别的分类问题。参数weight
给定时,其为分配给每一个类别的权重的一维张量(Tensor)。当数据集分布不均衡时,这是很有用的。
函数输入(input)应包含每一个类别的原始、非标准化分数。对于未批量化的输入,输入必须是大小为 ( C ) (C) (C)的张量, ( m i n i b a t c h , C ) (minibatch,C) (minibatch,C)或 ( m i n i b a t c h , C , d 1 , d 2 , . . . , d K ) (minibatch,C,d_1 ,d_2 ,... ,d_K) (minibatch,C,d1,d2,...,dK),在K维情况下, K ≥ 1 K \geq1 K≥1。
函数目标值(target)有两种情况,本文只介绍其中较为有效的一种情况,即target为类索引。
本文以下内容均为target为类索引的情况。
函数目标值(target)取值为在 [ 0 , C ) [0,C) [0,C)之间的类索引, C C C为类别数。参数reduction
设为'none'
时,交叉熵损失可描述如下:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n l o g e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) ⋅ 1 { y n / = i g n o r e _ i n d e x } (1) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T, \\ \large l_n = -w_{yn}log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}\cdot 1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}\tag{1} l(x,y)=L={l1,...,lN}T,ln=−wynlog∑c=1Cexp(xn,c)exp(xn,yn)⋅1{yn/=ignore_index}(1)
其中, x x x是输入, y y y是目标值, w w w是weight, C C C是类别数, N N N为batch size。在reduction
不为'none'
时(默认为'mean'
),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n / = i g n o r e _ i n d e x } l n , i f r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f r e d u c t i o n = ‘ s u m ’ . (2) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{2} l(x,y)=⎩⎨⎧∑n=1N∑n=1Nwyn⋅1{yn/=ignore_index}1ln,ifreduction=‘mean’;∑n=1Nln,ifreduction=‘sum’.(2)
需要指出的是,在这种情况下的交叉熵损失等价于
LogSoftmax
和NLLLoss
的组合。1
因此,我们可以从LogSoftmax
和NLLLoss
来深度解析交叉熵损失函数的内部运算。
LogSoftmax()函数2公式如下:
L o g S o f t m a x ( x i ) = l o g ( e x p ( x i ) ∑ j e x p ( x j ) ) (3) LogSoftmax(x_i) = log(\frac{exp(x_i)}{\sum_{j}exp(x_j)}) \tag{3} LogSoftmax(xi)=log(∑jexp(xj)exp(xi))(3)
即,先对输入值进行Softmax归一化处理,然后对归一化值取对数。这部分对应公式(1)中的 log e x p ( x n , y n ) ∑ c = 1 C e x p ( x n , c ) \textcolor{red}{\log\frac{exp(x_{n,y_n})}{\sum_{c=1}^{C}exp(x_{n,c})}} log∑c=1Cexp(xn,c)exp(xn,yn)。
代码示例如下:
>>> import torch.nn as nn
>>> SM = nn.Softmax(dim=1) #Softmax函数
>>> x = torch.tensor([[1.0,3.0,4.0],[7.0,3.0,8.0],[9.0,7.0,5.0]])
>>> x
tensor([[1., 3., 4.],
[7., 3., 8.],
[9., 7., 5.]])
>>> output_SM = SM(x) #第一步,对x进行Softmax归一化处理
>>> output_SM
#每一行元素相加之和等于1
tensor([[0.0351, 0.2595, 0.7054],
[0.2676, 0.0049, 0.7275],
[0.8668, 0.1173, 0.0159]])
>>> out_L_SM = torch.log(output_SM) #第二步,对输出取log
>>> out_L_SM
tensor([[-3.3490, -1.3490, -0.3490],
[-1.3182, -5.3182, -0.3182],
[-0.1429, -2.1429, -4.1429]])
#直接使用LogSoftmax函数,一步到位
>>> L_SM = nn.LogSoftmax(dim=1)
>>> out_L_SM_ = L_SM(x)
>>> out_L_SM_
tensor([[-3.3490, -1.3490, -0.3490],
[-1.3182, -5.3182, -0.3182],
[-0.1429, -2.1429, -4.1429]])
Pytorch中的NLLLoss函数3“名不副实”,虽然名为负对数似然函数,但其内部并没有进行对数计算,而只是对输入值求平均后取负(函数参数reduction
为默认值'mean'
,参数weight
为默认值'none'
时)。
官网介绍如下:
CLASS torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’)
参数reduction
值为'none'
时:
l ( x , y ) = L = { l 1 , . . . , l N } T , l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c / = i g n o r e _ i n d e x } , (4) \large l(x,y) = L = \left \{ l_1,...,l_N \right \}^T,\ l_n = -w_{yn}x_{n,yn}, w_c = weight[c]\cdot1\left \{ c\mathrlap{\,/}{=}ignore\_index\right \},\tag{4} l(x,y)=L={l1,...,lN}T, ln=−wynxn,yn,wc=weight[c]⋅1{c/=ignore_index},(4)
其中, x x x为输入, y y y为目标值, w w w为weight, N N N为batch size。
参数reduction
值不为'none'
时(默认为'mean'
),有:
l ( x , y ) = { ∑ n = 1 N 1 ∑ n = 1 N w y n l n , i f r e d u c t i o n = ‘ m e a n ’ ; ∑ n = 1 N l n , i f r e d u c t i o n = ‘ s u m ’ . (5) \large l(x,y) = \left\{\begin{matrix} \sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn}}l_n, \quad if \, reduction=‘mean’; \\ \sum_{n=1}^{N}l_n, \quad if \, reduction=‘sum’ . \end{matrix}\right. \tag{5} l(x,y)=⎩⎨⎧∑n=1N∑n=1Nwyn1ln,ifreduction=‘mean’;∑n=1Nln,ifreduction=‘sum’.(5)
可以看出,当reduction
为'mean'
时,即是对 l n l_n ln求加权平均值。weight
参数默认为1,因此默认情况下,即是对 l n l_n ln求平均值。又 l n = − w y n x n , y n l_n = -w_{yn}x_{n,yn} ln=−wynxn,yn,所以weight
为默认值1时, l n = − x n , y n l_n=-x_{n,yn} ln=−xn,yn。故此时,即是对 x x x求平均后取负。 这部分对于公式(2)中的 ∑ n = 1 N 1 ∑ n = 1 N w y n ⋅ 1 { y n / = i g n o r e _ i n d e x } l n \textcolor{red}{\sum_{n=1}^{N}\frac{1}{\sum_{n=1}^{N}w_{yn} \cdot1\left \{ y_n\mathrlap{\,/}{=}ignore\_index \right \}}l_n} ∑n=1N∑n=1Nwyn⋅1{yn/=ignore_index}1ln。
实例代码验证如下:
>>> import torch
>>> NLLLoss = torch.nn.NLLLoss() #Pytorch负对数似然损失函数
>>> input = torch.randn(3,3)
>>>input
tensor([[1.4550, 2.3858, 1.1724],
[0.4952, 1.5870, 0.9594],
[1.4170, 0.4525, 0.2519]])
>>>target = torch.tensor([1,0,2]) #类索引目标值
>>> loss = NLLLoss(input, target)
>>> loss
tensor(-1.0443)
平均取负有: V a l u e = − 1 3 ( 2.3858 + 0.4952 + 0.2519 ) = − 1.0443 Value = -\frac{1}{3}\left ( 2.3858+0.4952+0.2519 \right ) =-1.0443 Value=−31(2.3858+0.4952+0.2519)=−1.0443
显然,平均取负结果和NLLLoss运算结果相同。
注:笔者窃以为,公式(5)中上式可写为 ∑ n = 1 N l n ∑ n = 1 N w y n \frac{\sum_{n=1}^{N}l_n}{\sum_{n=1}^{N}w_{yn}} ∑n=1Nwyn∑n=1Nln,如此则更容易理解。公式(2)同理。
本文通过将CrossEntropyLoss
拆解为LogSoftmax
和NLLLoss
两步,对交叉熵损失内部计算做了深度的解析,以更清晰地理解交叉熵损失函数。需要指出的是,本文所介绍的内容,只是对于CrossEntropyLoss的target为类索引的情况,CrossEntropyLoss的target还可以是每个类别的概率(Probabilities for each class),这种情况有所不同。
学习总结,以作分享,如有不妥,敬请指出。
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss ↩︎
https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html?highlight=logsoftmax#torch.nn.LogSoftmax ↩︎
https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss ↩︎