交叉熵损失(CrossEntropyLoss)函数是分类任务里最常见的损失函数。
当然,CrossEntropy(以下简称“CE”)的作用不仅仅是在简单的分类任务里,比如最近大火的图文多模态模型CLIP就用到CE来进行对比学习,稍微改造一下变成Symmetrical Cross Entropy。
交叉熵不具有对称性,所以交换变量位置,会导致数值不同。所以,在使用Pytorch内置交叉熵损失函数时,记得注意顺序:(label记得放在后面)
import torch.nn.functional as F
loss = F.cross_entropy(logits, label)
简单来说,交叉熵可以用来描述两个分布的差别。先看公式,
C r o s s E n t r o p y ( p , q ) = − Σ k = 1 K l o g ( p ( k ) ) q ( k ) CrossEntropy(p, q) = -\Sigma_{k=1}^K log(p(k))q(k) CrossEntropy(p,q)=−Σk=1Klog(p(k))q(k)
这里的p
是我们模型输出,而q
是我们的label。
通常,在分类任务里,我们的label通常是one-hot
类型,假设我们应对三分类任务,有label = [1, 0, 0]代表第0类是正例。那么,CE就变得好算了:
logits = [0.3, 0.5, 0.2]
label = [1, 0, 0]
CE = -(1*log(0.3)) - (0*log(0.5)) - (0*log(0.2)) = -log(0.3)
上面是一段伪代码,代表CE的计算。在Pytorch里,CE里面有一个暗坑,没注意就会犯错,那就是CE会先给你做一个softmax(不了解softmax的,可戳《通俗易懂的Softmax》)再进行上述计算。我们看实验代码便知:
import torch
from math import log
import torch.nn.functional as F
logits = torch.Tensor([[0.8, 0.3, 1.2]]) # pre-softmax
label = torch.Tensor([[1, 0, 0]])
loss = F.cross_entropy(logits, label)
print(loss)
softmax_logits = F.softmax(logits, dim=-1) # post-softmax
loss_check = - (log(softmax_logits[0][0]))
print(loss_check)
上面代码的loss
和loss_check
是相等的,由此可知,Pytorch内置的CE函数会将logits加一层softmax,所以,咱们不用画蛇添足在输入CE前再加一个softmax了。
面对一个batch时,label有两种表达方式。
logits
具有相同的shape:logits = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
label = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
loss = F.softmax(logits, label, reduction='mean')
比较好理解,分别算logits[i]和label[i]的CE,最后求平均,或者设置reduction='sum'
求和;
n
,label的shape可以为(1, n)
;logits = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
label = torch.arange(3)
loss = F.softmax(logits, label, reduction='mean')
这段代码和上面的代码是等效的。当label的维度低于logits纬度时,CE函数会把label理解成one-hot的简写,比如label = [1, 2, 0]等效于[[0, 1, 0], [0, 0, 1], [1, 0, 0]]。
第二种用法属实诡异,一不小心就理解不对代码的意思。
参考:https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss