由mnist引发的思考,pytorch中的交叉熵误差函数nn.CrossEntropy做了什么?

文章目录

  • 引入
  • 实验一
  • 实验二
  • 结论

引入

在MNIST手写体实验中,关于在交叉熵损失函数计算误差时,神经网络输出为10个,当标签设置为何种情况时才能满足交叉熵损失函数的计算公式,来探究这个问题。

实验一

直接打印出每个数据的标签内容

代码如下:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([  # 设置预处理的方式,里面依次填写预处理的方法
    transforms.ToTensor(),  # 将数据转换为tensor对象
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
if __name__ == '__main__':
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=2)
    for i, data in enumerate(trainloader, 0):
        data_, label = data
        print("label:"+str(label.numpy()))
        if i == 10:
            break

label:[6]
label:[7]
label:[2]
label:[3]
label:[5]
label:[5]
label:[7]
label:[6]
label:[1]
label:[6]
label:[9]

从上面的实验看出label就是该手写体所代表的数字,那么nn.CrossEntorpy是如何对只有一个值的tensor进行计算的?

实验二

经过查阅资料,得出label中的tensor代表的是所有预测种类中的第几类,例如上面实验从的第一个label:[6]就代表为第6类。


该函数的数学公式可以写为:

由mnist引发的思考,pytorch中的交叉熵误差函数nn.CrossEntropy做了什么?_第1张图片
其中x代表损失函数式的输入,y代表target(或标签)中所代表的类别,C为所有的类别数量。

import numpy as np
import torch
from torch import nn
from torchvision import datasets, transforms

transform = transforms.Compose([  # 设置预处理的方式,里面依次填写预处理的方法
    transforms.ToTensor(),  # 将数据转换为tensor对象
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
if __name__ == '__main__':
    inputs = torch.Tensor([[-0.5, -0.2, -0.3]])
    target = torch.tensor([0])
    criterion = nn.CrossEntropyLoss()
    output = criterion(inputs, target)
    print("通过nn.CrossEntropy计算的结果", output)  # 使用nn.CrossEntropy函数
    # 直接计算
    my_out = -inputs[:, 0] + np.log(torch.sum(torch.exp(inputs[0:1])))
    print("通过公式计算的结果:", my_out)
    # 通过logSoftmax和NLLose计算的结果
    log_softmax_function = nn.LogSoftmax(dim=1)
    loss = nn.NLLLoss()
    logSoftmax_NLLose_output = loss(log_softmax_function(inputs), target)
    print("通过LogSoftmax和NLLLose函数计算的结果", logSoftmax_NLLose_output)

通过nn.CrossEntropy计算的结果 tensor(1.2729)
通过公式计算的结果: tensor([1.2729])
通过LogSoftmax和NLLLose函数计算的结果 tensor(1.2729)

结论

PYTORCH中的CrossEntropy函数结合了取log的softmax函数和NLLOSE误差函数来计算loss,label中只用给出结果所分的类别编号即可。

你可能感兴趣的:(机器学习,pytorch,深度学习,神经网络)