交叉熵损失

交叉熵损失

  • BCELoss
  • BCEWithLogitsLoss
  • CrossEntropyLoss
  • 使用场景

BCELoss

全称为Binary CrossEntropy Loss,二值交叉熵损失。

l ( x , y ) = { l 1 , . . . , l N } T l(x, y) =\{l_1,..., l_N\}^T l(x,y)={l1,...,lN}T
l n = − w n [ y n l o g ( x n ) + ( 1 − y n ) l o g ( 1 − x n ) ] l_n = -w_n[y_n log(x_n) + (1-y_n)log(1-x_n)] ln=wn[ynlog(xn)+(1yn)log(1xn)]

预测值 x x x和真值 y y y的尺寸需保持一致,均为[N, C],或者[N, C, H, W]。且取值在[0, 1]之间。所以在输入BCELoss计算损失之前,需要对预测值 x x x n n . S i g m o i d ( ) nn.Sigmoid() nn.Sigmoid()处理,且真值需要做one-hot处理,真值位置的值为1,非真值位置的值为0。如果做label smoothing处理,真值位置的值可以为0.95,非真值位置的值为0.05。

假设最终需要预测N个类别,模型输出的logits尺寸为[B, N], 用BCELoss做监督训练时,使用sigmoid做归一化处理,dim=1维度上的值是相互独立的。所以本质上BCELoss也可以处理图片的多标签分类问题。

举例:
模型预测输出 x x x的尺寸为[4, 3], 代表输入4张图,每张图预测三个类别(龙,人,中国人)的概率。

x = torch.FloatTensor(
[[0.2, 0.8, 0.7], 
 [0.8, 0.2, 0.3], 
 [0.1, 0.9, 0.2],
 [0.2, 0.8, 0.9]])
x.requires_grad = True

真值的尺寸也为[4, 3], 代表输入的4张图,其真实类别的标签,采用one-hot编码。

y = torch.FloatTensor(
[[0, 1, 1], 
 [1, 0, 0], 
 [0, 1, 0],
 [0, 1, 1]])
BCEloss = torch.nn.BCELoss()
loss = BCELoss(x, y) # tensor(0.2160)

BCEWithLogitsLoss

BCEWithLogitsLoss = sigmoid + BCELoss

所以实际模型训练时,建议直接用BCEWithLogitsLoss,一步到位。

CrossEntropyLoss

CrossEntropyLoss = softmax + log + NLLLoss

函数输入时,
对分类问题,预测值的维度为[N, C], 真值维度为[N]。
对分割类问题,预测值的维度为[N, C, H, W], 真值维度为[N, H, W]。
真值必须是torch.long类型,不能是float类型。真值的取值范围均为[0, C-1]。所以用CrossEntropyLoss时,不能用label smoothing策略。

softmax对特定维度做归一化处理,比如分割问题的模型输出尺寸为[N, C, H, W],对dim=1做softmax处理。为了数值计算稳定性,防止logits中出现的极大值,再做一次log处理。最后通过NLLoss取真值index位置上的值,取负数后,即为最终的CrossEntropyLoss。 loss越接近0,说明模型参数越能准确拟合样本。

需要注意的是在特定维度做softmax后,该维度上的值是相互关联的,故不同于BCELoss,正负样本的loss均需计算,

l n = − w n [ y n l o g ( x n ) + ( 1 − y n ) l o g ( 1 − x n ) ] l_n = -w_n[y_n log(x_n) + (1-y_n)log(1-x_n)] ln=wn[ynlog(xn)+(1yn)log(1xn)]

CrossEntropyLoss计算正样本的loss即可。
l n = − w n [ y n l o g ( x n ) ] l_n = -w_n[y_n log(x_n)] ln=wn[ynlog(xn)]

使用场景

  1. 多标签分类,只能用BCEWithLogitsLoss
  2. 单标签分类,既可以用BCEWithLogitsLoss, 也可以CrossEntropyLoss,但需要注意对真值类型的要求。建议图片分类问题用BCEWithLogitsLoss,分割问题用CrossEntropyLoss。

你可能感兴趣的:(torch,深度学习)