The Difference Between NLLLoss and CrossEntropyLoss in PyTorch

Experiment code:

import torch
from torch import nn

logits = torch.randn(3, 3)  # raw scores for 3 samples, 3 classes (renamed from `input` to avoid shadowing the built-in)
print(logits)

sm = nn.Softmax(dim=1)

# NLLLoss expects log-probabilities, so apply softmax and log manually first
loss1 = torch.nn.NLLLoss()
target = torch.tensor([0, 2, 1])
print(loss1(torch.log(sm(logits)), target))

# CrossEntropyLoss takes the raw logits directly
loss2 = torch.nn.CrossEntropyLoss()
print(loss2(logits, target))

Experiment output:

tensor(1.6964)
tensor(1.6964)

Conclusion:
CrossEntropyLoss = softmax + log + NLLLoss, i.e. CrossEntropyLoss(x, target) is equivalent to NLLLoss(log(softmax(x)), target).
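The equivalence can also be checked with nn.LogSoftmax, which fuses the softmax and log steps and is more numerically stable than computing torch.log(softmax(x)) separately. A minimal sketch (the seed and tensor values here are illustrative, not from the original experiment):

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(3, 3)
target = torch.tensor([0, 2, 1])

# Path 1: CrossEntropyLoss on raw logits
ce = nn.CrossEntropyLoss()(logits, target)

# Path 2: LogSoftmax followed by NLLLoss
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)

# The two losses agree up to floating-point tolerance
print(torch.allclose(ce, nll))
```

In practice this is why a classification network's final layer should output raw logits when training with CrossEntropyLoss; adding an explicit Softmax layer before it would apply softmax twice.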
