Pytorch 中损失函数详解

参考链接:详解torch.nn.NLLLOSS - 知乎分类问题的损失函数中,经常会遇到torch.nn.NLLLOSS。torch.nn.NLLLOSS通常不被独立当作损失函数,而需要和softmax、log等运算组合当作损失函数。 torch.nn.NLLLOSS官方链接: NLLLoss - PyTorch 1.9.0 documentat…https://zhuanlan.zhihu.com/p/383044774

一、torch.nn.NLLLOSS运算规则

from torch import nn
import torch

# nllloss首先需要初始化
nllloss = nn.NLLLoss() # 可选参数中有 reduction='mean', 'sum', 默认mean

在使用nllloss时,需要有两个张量,一个是预测向量,一个是label

predict = torch.Tensor([[2, 3, 1]])  # shape: (n, category)
label = torch.tensor([1]) # shape: (n,)
  • 这里解释一下predict和label,label的shape是n,表示了n个向量对应的正确类别,比如这里label为1,则表明向量(2,3,1)对应的类别是1;
  • predict则表示每个类别预测的概率,比如向量(2,3,1)则表示类别0,1,2预测的概率分别为(2,3,1)(先忽略概率大于1的问题)
  1. predict shape为(1, category)的情况
# 
predict = torch.Tensor([[2, 3, 1]])
label = torch.tensor([1])
nllloss(predict, label)
# output: tensor(-3.)

nllloss对两个向量的操作为,将predict中的向量,在label中对应的index取出,并取负号输出。label中为1,则取2,3,1中的第1位3,取负号后输出

2. predict shape为(n, category)的情况

predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
nllloss(predict, label)
# output: tensor(-6)

nllloss对两个向量的操作为,继续将predict中的向量,在label中对应的index取出,并取负号输出。label中为1,则取2,3,1中的第1位3,label第二位为2,则取出3,7,9的第2位9,将两数取平均后加负号后输出

这时就可以看到最开始的nllloss初始化的时候,如果参数reduction取'mean',就是上述结果。如果reduction取'sum',那么各行取出对应的结果,就是取sum后输出,如下所示:

nllloss = nn.NLLLoss( reduction='sum')
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
nllloss(predict, label)
# output: tensor(-12)

二、与torch.nn.CrossEntropyLoss的区别

torch.nn.CrossEntropyLoss相当于softmax + log + nllloss。

上面的例子中,预测的概率大于1明显不符合预期,可以使用softmax归一,取log后是交叉熵,取负号是为了符合loss越小,预测概率越大。

所以使用nll loss时,可以这样操作

nllloss = nn.NLLLoss()
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
predict = torch.log(torch.softmax(predict, dim=-1))
label = torch.tensor([1, 2])
nllloss(predict, label)
# output: tensor(0.2684)

而使用torch.nn.CrossEntropyLoss可以省去softmax + log

cross_loss = nn.CrossEntropyLoss()

predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
cross_loss(predict, label)
# output: tensor(0.2684)

你可能感兴趣的:(Pytorch,pytorch)