PyTorch之torch.nn.CrossEntropyLoss()

  1. 简介
    信息熵: 按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度
    在这里插入图片描述
    交叉熵: 使用拟合分布q来表示来自真实分布p的编码长度的期望,即平均编码长度
    在这里插入图片描述
    多分类任务中的交叉熵损失函数
    在这里插入图片描述

  2. 代码

1)导入包

import torch
import torch.nn as nn

2)准备数据
在图片单标签分类时,输入m张图片,输出一个m x N的Tensor,其中N是分类个数。比如输入3张图片,分三类,最后的输出是一个3 x 3的Tensor,举个例子:

x_input=torch.randn(3,3)
print('x_input:\n',x_input) 
y_target=torch.tensor([1,2,0])

PyTorch之torch.nn.CrossEntropyLoss()_第1张图片
3)计算概率分布
第123行分别是第123张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
然后对每一行使用Softmax,这样可以得到每张图片的概率分布。

softmax_func=nn.Softmax(dim=1)
soft_output=softmax_func(x_input)
print('soft_output:\n',soft_output)

在这里插入图片描述
这里dim的意思是计算Softmax的维度,这里设置dim=1,可以看到每一行的加和为1。比如第一行0.1022+0.3831+0.5147=1。

4)对Softmax的结果取自然对数

log_output=torch.log(soft_output)
print('log_output:\n',log_output)

在这里插入图片描述
对比softmax与log的结合与nn.LogSoftmaxloss(负对数似然损失)的输出结果,两者是一致的。

logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n',logsoftmax_output)

在这里插入图片描述
5)NLLLoss
NLLLoss的结果就是把上面的输出与y_label对应的那个值拿出来,再去掉负号,再求均值。

nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(logsoftmax_output,y_target)
print('nlloss_output:\n',nlloss_output)

y_target中[1, 2, 0]对应上述第一行的第二个,第二行的第三个,第三行的第1个:
(0.9594+0.4241+0.5265)/3=0.6367
在这里插入图片描述
6) CrossEntropyLoss()

crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(x_input,y_target)
print('crossentropyloss_output:\n',crossentropyloss_output)

在这里插入图片描述

参考链接:
https://blog.csdn.net/qq_22210253/article/details/85229988
https://zhuanlan.zhihu.com/p/98785902
https://zhuanlan.zhihu.com/p/56638625

你可能感兴趣的:(PyTorch,python,深度学习,机器学习,算法,人工智能)