【pytorch】2.9 交叉熵损失函数 nn.CrossEntropyLoss()

一、交叉熵损失函数

交叉熵损失多用于 多分类函数,下面我们通过拆解交叉熵的公式来理解其作为损失函数的意义

假设我们在做一个 n分类的问题,模型预测的输出结果是 [ x 1 , x 2 , x 3 , . . . . , x n ] [x_1, x_2, x_3, ...., x_n] [x1,x2,x3,....,xn]
然后,我们选择交叉熵损失函数作为目标函数,通过反向传播调整模型的权重

nn.CrossEntropyLoss() 的公式为:
l o s s ( x , c l a s s ) = − l o g ( e x [ c l a s s ] ∑ j e x j ) = − x [ c l a s s ] + l o g ( ∑ j e x j ) \begin{aligned} loss(x, class) &= -log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}})\\ &= -x_{[class]} + log(\sum_j e^{x_{j}}) \end{aligned} loss(x,class)=log(jexjex[class])=x[class]+log(jexj)

  • x 是预测结果,是一个向量,其元素个数是需要由模型保证的,保证和分类数一样多
  • class 表示这个样本的实际标签,比如,样本实际属于分类2,那么class=2
    x [ c l a s s ] x_{[class]} x[class] 就是 x 2 x_2 x2,就是取测试结果向量中的第二个元素,也就是取其真实分类对应的那个预测值

上面铺垫完了,接下来,我们来拆解公式,理解公式:

1、首先,交叉熵损失函数中包含了一个最基础的部分: s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_je^{x_{j}}} softmax(xi)=jexjexi

softmax 将分类的结果做了归一化:

  • 先经过 e x i e^{x_i} exi 的运算,转换为非负数
  • 再通过公式 e x i ∑ j = 0 n e x j \frac{e^{x_i}}{\sum_{j=0}^ne^{x_j}} j=0nexjexi计算出该样本被分到分类 i i i的概率,这里所有分类概率相加的总和等于1。

2、我们想要使预测结果中,真实分类的那个值的概率接近 100%。 我们取出真实分类的那个值:
e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} jexjex[class],我们希望它的值是 100%

3、作为损失函数的意义是:当预测结果越接近真实值损失函数的值越接近于0

我们把 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} jexjex[class] 取log,再取反,就能保证当 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} jexjex[class] 越接近于100%, l o s s = − l o g ( e x [ c l a s s ] ∑ j e x j ) loss=-log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) loss=log(jexjex[class]) 越接近0。

附上一张 − l o g x -log^x logx 的图
【pytorch】2.9 交叉熵损失函数 nn.CrossEntropyLoss()_第1张图片


二、nn.CrossEntropyLoss

pytorch 的交叉熵损失函数

nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

如果设置了权重参数weight,则
l o s s ( x , c l a s s ) = w e i g h t [ c l a s s ] ( − l o g ( e x [ c l a s s ] ∑ j e x j ) ) loss(x, class) = weight_{[class]}(-log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}})) loss(x,class)=weight[class](log(jexjex[class]))
weigh 为每个类别的loss设置权值,常用于类别不均衡问题。weight必须是float类型的tensor,其长度要与类别C一致,即每一个类别都要设置weight


三、应用

假设有4张图片,或者说batch_ size=4。我们需要把这4张图片分类到5个类别上去,比如说:鸟,狗,猫,汽车,船
经过网络计算后,我们得到了预测结果:predict,size为[4, 5]
其真实标签为 label,size为 [4]
接下来使用 nn.CrossEntropyLoss() 计算 预测结果predict真实值label 的交叉熵损失,可以

import torch
import torch.nn as nn

# -----------------------------------------
# 定义数据: batch_size=4;  一共有5个分类
# label.size() : torch.Size([4])
# predict.size(): torch.Size([4, 5])
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)

# -----------------------------------------
# 直接调用函数 nn.CrossEntropyLoss() 计算 Loss
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)

【pytorch】2.9 交叉熵损失函数 nn.CrossEntropyLoss()_第2张图片


nn.CrossEntropyLoss() 可以拆解成如下3个步骤,或者说可以由如下3个操作替换,其运算结果一毛一样:

  1. softmax:对每张图片的分类结果做softmax, softmax详细介绍
  2. log:对上面的结果 取log
    (步骤1 和 步骤2 可以合并为 nn.logSoftmax() )
  3. NLL:nn.NLLLoss(a, b) 的操作是从a 中取出b对应的那个值(b中存的是 index值),再去掉负号(取反),然后求和取均值
import torch
import torch.nn as nn

torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])

softmax = nn.Softmax(dim=1)
nll = nn.NLLLoss()

temp1 = softmax(predict)
temp2 = torch.log(temp1)
output = nll(temp2, label)
print(output)   # tensor(1.5230)

纯手撸版本

import torch

torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])

# softmax
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)

# log
temp2 = torch.log(temp1)

# nll
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3
output = torch.mean(temp4)

print(output)    # tensor(1.5230)

你可能感兴趣的:(#,pytorch,pytorch,深度学习,python)