【AI数学】交叉熵损失函数CrossEntropy

交叉熵损失(CrossEntropyLoss)函数是分类任务里最常见的损失函数。

当然,CrossEntropy(以下简称“CE”)的作用不仅仅是在简单的分类任务里,比如最近大火的图文多模态模型CLIP就用到CE来进行对比学习,稍微改造一下变成Symmetrical Cross Entropy

交叉熵不具有对称性,所以交换变量位置,会导致数值不同。所以,在使用Pytorch内置交叉熵损失函数时,记得注意顺序:(label记得放在后面)

import torch.nn.functional as F
loss = F.cross_entropy(logits, label)

简单来说,交叉熵可以用来描述两个分布的差别。先看公式,
C r o s s E n t r o p y ( p , q ) = − Σ k = 1 K l o g ( p ( k ) ) q ( k ) CrossEntropy(p, q) = -\Sigma_{k=1}^K log(p(k))q(k) CrossEntropy(p,q)=Σk=1Klog(p(k))q(k)
这里的p是我们模型输出,而q是我们的label。
通常,在分类任务里,我们的label通常是one-hot类型,假设我们应对三分类任务,有label = [1, 0, 0]代表第0类是正例。那么,CE就变得好算了:

logits = [0.3, 0.5, 0.2]
label = [1, 0, 0]
CE = -(1*log(0.3)) - (0*log(0.5)) - (0*log(0.2)) = -log(0.3)

上面是一段伪代码,代表CE的计算。在Pytorch里,CE里面有一个暗坑,没注意就会犯错,那就是CE会先给你做一个softmax(不了解softmax的,可戳《通俗易懂的Softmax》)再进行上述计算。我们看实验代码便知:

import torch
from math import log
import torch.nn.functional as F

logits = torch.Tensor([[0.8, 0.3, 1.2]]) # pre-softmax
label = torch.Tensor([[1, 0, 0]])
loss = F.cross_entropy(logits, label)
print(loss)

softmax_logits = F.softmax(logits, dim=-1) # post-softmax
loss_check = - (log(softmax_logits[0][0]))
print(loss_check)

上面代码的lossloss_check是相等的,由此可知,Pytorch内置的CE函数会将logits加一层softmax,所以,咱们不用画蛇添足在输入CE前再加一个softmax了。


进阶

面对一个batch时,label有两种表达方式。

  1. logits具有相同的shape:
logits = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
label = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
loss = F.softmax(logits, label, reduction='mean')

比较好理解,分别算logits[i]和label[i]的CE,最后求平均,或者设置reduction='sum'求和;

  1. 假设batch size为n,label的shape可以为(1, n)
logits = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
label = torch.arange(3)
loss = F.softmax(logits, label, reduction='mean')

这段代码和上面的代码是等效的当label的维度低于logits纬度时,CE函数会把label理解成one-hot的简写,比如label = [1, 2, 0]等效于[[0, 1, 0], [0, 0, 1], [1, 0, 0]]。
第二种用法属实诡异,一不小心就理解不对代码的意思。

参考:https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss

你可能感兴趣的:(AI数学,人工智能,深度学习,python,交叉熵,CrossEntropy)