交叉熵损失函数分类_PyTorch学习笔记——多分类交叉熵损失函数

理解交叉熵

关于样本集的两个概率分布p和q,设p为真实的分布,比如[1, 0, 0]表示当前样本属于第一类,q为拟合的分布,比如[0.7, 0.2, 0.1]。

按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度(信息熵):

如果使用拟合分布q来表示来自真实分布p的编码长度的期望,即平均编码长度(交叉熵):

直观上,用p来描述样本是最完美的,用q描述样本就不那么完美,根据吉布斯不等式,

恒成立,当q为真实分布时取等,我们将由q得到的平均编码长度比由p得到的平均编码长度多出的bit数称为相对熵,也叫KL散度:

在机器学习的分类问题中,我们希望缩小模型预测和标签之间的差距,即KL散度越小越好,在这里由于KL散度中的

项不变(在其他问题中未必),故在优化过程中只需要关注交叉熵就可以了,因此一般使用交叉熵作为损失函数。

多分类任务中的交叉熵损失函数

其中

是一个概率分布,每个元素

表示样本属于第i类的概率;

是样本标签的onehot表示,当样本属于第类别i时

,否则

;c是样本标签。

PyTorch中的交叉熵损失函数实现

PyTorch提供了两个类来计算交叉熵,分别是CrossEntropyLoss() 和NLLLoss()。torch.nn.CrossEntropyLoss()

类定义如下

torch.nn.CrossEntropyLoss(

weight=None,

ignore_index=-100,

reduction="mean",

)

表示一个样本的非softmax输出,c表示该样本的标签,则损失函数公式描述如下,

如果weight被指定,

其中,

import torch

import torch.nn as nn

model = nn.Linear(10, 3)

criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 10)

y = torch.randint(0, 3, size=(16,)) # (16, )

logits = model(x) # (16, 3)

loss = criterion(logits, y)torch.nn.NLLLoss()

类定义如下

torch.nn.NLLLoss(

weight=None,

ignore_index=-100,

reduction="mean",

)

表示一个样本对每个类别的对数似然(log-probabilities),c表示该样本的标签,损失函数公式描述如下,

其中,

import torch

import torch.nn as nn

model = nn.Sequential(

nn.Linear(10, 3),

nn.LogSoftmax()

)

criterion = nn.NLLLoss()

x = torch.randn(16, 10)

y = torch.randint(0, 3, size=(16,)) # (16, )

out = model(x) # (16, 3)

loss = criterion(out, y)总结

torch.nn.CrossEntropyLoss在一个类中组合了nn.LogSoftmax和nn.NLLLoss,This criterion combines nn.LogSoftmax() and nn.NLLLoss()in one single class. The input is expected to contain scores for each class.

你可能感兴趣的:(交叉熵损失函数分类)