Reference: "softmax + log = logsoftmax, logsoftmax + nllloss = crossentropy", LUQC638's blog on CSDN: https://blog.csdn.net/weixin_61445075/article/details/122205690?spm=1001.2014.3001.5501
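Here are the identities softmax + log = logsoftmax and logsoftmax + nllloss = crossentropy written out. Assume a batch of N rows of logits x_n with integer targets t_n, and PyTorch's default settings (mean reduction, no class weights):

```latex
\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad
\log\operatorname{softmax}(x)_i = x_i - \log\sum_j e^{x_j}

\operatorname{NLL}(L, t) = -\frac{1}{N}\sum_{n=1}^{N} L_{n,\,t_n}, \qquad
\operatorname{CE}(X, t) = \operatorname{NLL}\big(\log\operatorname{softmax}(X),\, t\big)
```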
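Both torch.nn and torch.nn.functional provide dropout, relu, log-softmax, cross entropy and similar operations, but they work differently. torch.nn.functional applies a function directly to the tensor you want to process. torch.nn first creates an object (a module) and then calls that object on the tensor.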
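Below is a minimal sketch of the two styles side by side, assuming a recent PyTorch version. For a stateless operation such as ReLU the results are identical. Dropout shows one practical difference. An nn.Dropout module follows the module's train()/eval() mode, while F.dropout has to be told the mode through its training argument (which defaults to True).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4, 5)

# Module style: build the object first, then call it on the tensor
relu_module = nn.ReLU()
y_module = relu_module(x)

# Functional style: pass the tensor straight into the function
y_func = F.relu(x)
print(torch.equal(y_module, y_func))  # True

# Dropout module: it tracks train/eval mode for you
drop = nn.Dropout(p=0.5)
drop.eval()                      # in eval mode nn.Dropout returns its input unchanged
print(torch.equal(drop(x), x))   # True

# Functional dropout: the mode must be passed explicitly on every call
y_drop = F.dropout(x, p=0.5, training=False)
print(torch.equal(y_drop, x))    # True
```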
import torch
import torch.nn as nn
import torch.nn.functional as F
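if __name__ == '__main__':
    # Example of target with class indices
    input = torch.randn(3, 5)
    print(f"Input is {input}")
    # Target class index for each of the 3 samples
    t = torch.tensor([1, 1, 2])
    # Module-style alternative:
    # s = nn.Softmax(dim=1)
    # print(f"The softmax result is {s(input)}")
    # Step 1: softmax over the class dimension
    s = F.softmax(input, dim=1)
    print(f"The softmax result is {s}")
    # Step 2: log of the probabilities (module style: l = torch.log(s(input)))
    l = torch.log(s)
    print(f"softmax + Log result is {l}")
    # Steps 1 and 2 in one call (module style: nn.LogSoftmax(dim=1))
    # ls = nn.LogSoftmax(dim=1)
    # print(f"The logsoftmax is {ls(input)}")
    ls = F.log_softmax(input, dim=1)
    print(f"The logsoftmax is {ls}")
    # Step 3: negative log-likelihood loss on the log-probabilities
    nl = F.nll_loss(l, t)
    print(f"The softmax + log + nllloss is {nl}")
    # Cross entropy on the raw logits does all three steps at once
    ce = F.cross_entropy(input, t)
    print(f"The cross entropy result is {ce}")
    # Same thing through a loss module object
    xent = nn.CrossEntropyLoss()
    print(f"The cross entropy result is {xent(input, t)}")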
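Here is a hand-written version of what F.nll_loss computes with its default settings (no class weights, reduction='mean'). For each row it picks out the log-probability at the target index, negates it, and averages over the batch. The helper name manual_nll is hypothetical and exists only for this illustration.

```python
import torch
import torch.nn.functional as F

def manual_nll(log_probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean negative log-likelihood of the target classes."""
    rows = torch.arange(log_probs.size(0))
    # log_probs[rows, target] selects log_probs[n, target[n]] for every sample n
    return -log_probs[rows, target].mean()

torch.manual_seed(0)
x = torch.randn(3, 5)
t = torch.tensor([1, 1, 2])

ls = F.log_softmax(x, dim=1)
print(manual_nll(ls, t))      # same value as F.nll_loss(ls, t)
print(F.nll_loss(ls, t))
print(F.cross_entropy(x, t))  # and as cross entropy on the raw logits
```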
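nn.Softmax, nn.LogSoftmax and nn.CrossEntropyLoss() construct module objects. To get a result, you call the object on the input.
F.softmax, F.log_softmax and F.cross_entropy are plain functions that take the input directly as an argument.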
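With the object style, configuration lives in the object. Take the reduction argument, which both nn.CrossEntropyLoss and F.cross_entropy accept. For the module it is fixed once at construction time, but in functional form it has to be repeated on every call. A small sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(3, 5)
t = torch.tensor([1, 1, 2])

# Object style: the reduction is stored in the module when it is constructed
sum_loss = nn.CrossEntropyLoss(reduction="sum")
a = sum_loss(x, t)

# Functional style: the same option is passed along with each call
b = F.cross_entropy(x, t, reduction="sum")

# reduction="none" returns one loss per sample; summing it gives the same total
per_sample = F.cross_entropy(x, t, reduction="none")
print(a, b, per_sample.sum())
```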
Output of one run (torch.randn draws a new random input each time, so your numbers will differ):
Input is tensor([[-0.2075, 0.9839, 1.4806, -0.8558, -1.3228],
[ 1.5713, 1.8871, 0.9782, 0.1114, 0.5329],
[-1.4224, -0.1338, -0.2187, -0.3194, -1.5902]])
The softmax result is tensor([[0.0948, 0.3120, 0.5126, 0.0496, 0.0311],
[0.2849, 0.3907, 0.1574, 0.0662, 0.1009],
[0.0846, 0.3069, 0.2820, 0.2550, 0.0715]])
softmax + Log result is tensor([[-2.3563, -1.1649, -0.6682, -3.0046, -3.4716],
[-1.2557, -0.9399, -1.8488, -2.7156, -2.2941],
[-2.4697, -1.1811, -1.2660, -1.3667, -2.6375]])
The logsoftmax is tensor([[-2.3563, -1.1649, -0.6682, -3.0046, -3.4716],
[-1.2557, -0.9399, -1.8488, -2.7156, -2.2941],
[-2.4697, -1.1811, -1.2660, -1.3667, -2.6375]])
The softmax + log + nllloss is 1.123583436012268
The cross entropy result is 1.123583436012268
The cross entropy result is 1.123583436012268
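As a sanity check on the printed numbers: with t = [1, 1, 2], the NLL loss takes the log-softmax entries at column 1 of row 0, column 1 of row 1 and column 2 of row 2, then negates and averages them. The four-decimal log-softmax values printed above reproduce 1.1236 up to rounding:

```python
import torch

# Log-softmax values copied from the printed output (rounded to 4 decimals)
ls = torch.tensor([[-2.3563, -1.1649, -0.6682, -3.0046, -3.4716],
                   [-1.2557, -0.9399, -1.8488, -2.7156, -2.2941],
                   [-2.4697, -1.1811, -1.2660, -1.3667, -2.6375]])
t = torch.tensor([1, 1, 2])

picked = ls[torch.arange(3), t]  # the entries -1.1649, -0.9399, -1.2660
loss = -picked.mean()            # (1.1649 + 0.9399 + 1.2660) / 3
print(loss.item())               # approximately 1.1236
```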
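Different routes, same destination: softmax + log matches log_softmax, and applying NLL loss to it gives the same value as F.cross_entropy and nn.CrossEntropyLoss.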
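The routes agree mathematically, but not always numerically. torch.log(F.softmax(...)) first produces probabilities and then takes their log, so any probability that underflows to 0 in float32 becomes -inf. F.log_softmax (and therefore F.cross_entropy) computes the log directly and stays finite. The logits below are an arbitrary extreme example, chosen only to trigger the underflow:

```python
import torch
import torch.nn.functional as F

# Extreme logits: exp(-200) is far below the smallest float32 value, so it underflows to 0
x = torch.tensor([[0.0, 200.0]])
t = torch.tensor([0])

two_step = torch.log(F.softmax(x, dim=1))  # tensor([[-inf, 0.]])
one_step = F.log_softmax(x, dim=1)          # tensor([[-200., 0.]])
print(two_step, one_step)

print(F.nll_loss(two_step, t))  # tensor(inf)
print(F.cross_entropy(x, t))    # tensor(200.)
```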