pytorch如何定义损失函数_PyTorch学习笔记——多分类交叉熵损失函数

pytorch如何定义损失函数_PyTorch学习笔记——多分类交叉熵损失函数_第1张图片

理解交叉熵

关于样本集的两个概率分布p和q,设p为真实的分布,比如[1, 0, 0]表示当前样本属于第一类,q为拟合的分布,比如[0.7, 0.2, 0.1]。

按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度(信息熵):

如果使用拟合分布q来表示来自真实分布p的编码长度的期望,即平均编码长度(交叉熵):

直观上,用p来描述样本是最完美的,用q描述样本就不那么完美,根据吉布斯不等式,

恒成立,当q为真实分布时取等,我们将由q得到的平均编码长度比由p得到的平均编码长度多出的bit数称为
相对熵,也叫KL散度:

在机器学习的分类问题中,我们希望缩小模型预测和标签之间的差距,即KL散度越小越好,在这里由于KL散度中的

项不变(
在其他问题中未必),故在优化过程中只需要关注交叉熵就可以了,因此一般使用交叉熵作为损失函数。

多分类任务中的交叉熵损失函数

其中

是一个概率分布,每个元素
表示样本属于第i类的概率;
是样本标签的onehot表示,当样本属于第类别i时
,否则
;c是样本标签。

PyTorch中的交叉熵损失函数实现

PyTorch提供了两个类来计算交叉熵,分别是CrossEntropyLoss() 和NLLLoss()。

  • torch.nn.CrossEntropyLoss()

类定义如下

torch

表示一个样本的
非softmax输出,c表示该样本的标签,则损失函数公式描述如下,

如果weight被指定,

其中,

import 
  • torch.nn.NLLLoss()

类定义如下

torch

表示一个样本对每个类别的对数似然(log-probabilities),c表示该样本的标签,损失函数公式描述如下,

其中,

import 
  • 总结

torch.nn.CrossEntropyLoss在一个类中组合了nn.LogSoftmax和nn.NLLLoss,

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. The input is expected to contain scores for each class.

你可能感兴趣的:(pytorch如何定义损失函数,交叉熵损失函数和focal,loss)