nn.LogSoftmax is the combination of nn.Softmax and torch.log.
import torch
import torch.nn as nn
# assume batch_size = 1 here
x_input = torch.randn(3, 4) # random input: predictions for 3 objects over 4 classes, i.e. each object's score for each of the 4 classes
print('x_input:\n', x_input)
y_target = torch.tensor([1, 2, 0]) # i.e. the first object belongs to class 1, the second to class 2, the third to class 0
# apply softmax to the input; each row now sums to 1
softmax_func = nn.Softmax(dim=1) # dim=1 normalizes across the class dimension (along each row); dim=0 would normalize down each column
soft_output = softmax_func(x_input)
print('soft_output:\n', soft_output)
# take the log of the softmax output; note torch.log is the natural log (ln)
log_output = torch.log(soft_output)
print('log_output:\n', log_output)
# compare softmax followed by log against nn.LogSoftmax; the two outputs are identical
logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n', logsoftmax_output)
# PyTorch's NLLLoss (negative log-likelihood loss) defaults to reduction='mean'
nllloss_func = nn.NLLLoss()
nlloss_output = nllloss_func(logsoftmax_output, y_target)
print('nlloss_output:\n', nlloss_output)
# use PyTorch's nn.CrossEntropyLoss directly and check that it matches LogSoftmax + NLLLoss
crossentropyloss = nn.CrossEntropyLoss()
crossentropyloss_output = crossentropyloss(x_input, y_target)
print('crossentropyloss_output:\n', crossentropyloss_output)
Test output:
x_input:
tensor([[-0.7195, 0.7233, 0.4887, -0.2147],
[ 2.0751, 0.4685, -0.9343, -0.6259],
[ 0.4000, -2.5932, -0.2298, 1.0522]])
soft_output:
tensor([[0.0977, 0.4135, 0.3270, 0.1618],
[0.7593, 0.1523, 0.0374, 0.0510],
[0.2855, 0.0143, 0.1521, 0.5481]])
log_output:
tensor([[-2.3260, -0.8832, -1.1178, -1.8212],
[-0.2754, -1.8820, -3.2848, -2.9764],
[-1.2535, -4.2467, -1.8833, -0.6013]])
logsoftmax_output:
tensor([[-2.3260, -0.8832, -1.1178, -1.8212],
[-0.2754, -1.8820, -3.2848, -2.9764],
[-1.2535, -4.2467, -1.8833, -0.6013]])
nlloss_output:
tensor(1.8071)
crossentropyloss_output:
tensor(1.8071)
# by hand: 1.8071 = (0.8832 + 3.2848 + 1.2535) / 3, the mean of the negative log-probabilities at the target classes
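As a quick sanity check, the sketch below (an addition, not part of the original test run; it assumes logsoftmax_output and y_target from the script above are still in scope) picks out each target class's log-probability with torch.gather and averages the negated values; it should reproduce the 1.8071 above.
# manual verification: gather the log-probability of each target class and average
picked = torch.gather(logsoftmax_output, dim=1, index=y_target.unsqueeze(1)) # shape (3, 1)
manual_loss = -picked.mean()
print('manual_loss:\n', manual_loss) # expected to match nlloss_output and crossentropyloss_output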
import torch
import torch.nn as nn
# batch_size = 2, predictions for 3 objects over 4 classes
x_input = torch.randn(2, 3, 4) # (B, N, C): B is the batch size, N the number of objects, C the number of classes
print('x_input:\n', x_input)
# y_target has shape (2, 3)
y_target = torch.tensor([[1, 2, 0],[0, 1, 3]])
# apply softmax to the input; each row now sums to 1
softmax_func = nn.Softmax(dim=2) # dim=2 normalizes across the class dimension (the last axis)
soft_output = softmax_func(x_input)
print('soft_output:\n', soft_output)
# take the log of the softmax output; note torch.log is the natural log (ln)
log_output = torch.log(soft_output)
print('log_output:\n', log_output)
# compare softmax followed by log against nn.LogSoftmax; the two outputs are identical
logsoftmax_func=nn.LogSoftmax(dim=2) # dim=2 is the class dimension
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n', logsoftmax_output)
# PyTorch's NLLLoss (negative log-likelihood loss) defaults to reduction='mean'
nllloss_func = nn.NLLLoss()
nlloss_output = nllloss_func(logsoftmax_output.permute(0, 2, 1), y_target) # NLLLoss expects the class dim at index 1, so permute (B,N,C) -> (B,C,N)
print('nlloss_output:\n', nlloss_output)
# use PyTorch's nn.CrossEntropyLoss directly and check that it matches LogSoftmax + NLLLoss
crossentropyloss = nn.CrossEntropyLoss()
crossentropyloss_output = crossentropyloss(x_input.permute(0, 2, 1), y_target) # CrossEntropyLoss also expects the class dim at index 1: (B,N,C) -> (B,C,N)
print('crossentropyloss_output:\n', crossentropyloss_output)
Test output:
x_input:
tensor([[[-0.1371, -0.2454, 0.3631, 1.6234],
[ 1.0808, -0.2933, -0.3308, 0.4915],
[-0.6203, 0.1524, 0.2854, 1.2199]],
[[ 0.7083, 0.6273, -0.2904, 0.6575],
[-0.3896, -2.0054, 0.7636, 0.1116],
[ 2.7022, 1.1479, -0.1215, -0.9459]]])
soft_output:
tensor([[[0.1068, 0.0959, 0.1761, 0.6212],
[0.4874, 0.1234, 0.1188, 0.2704],
[0.0838, 0.1814, 0.2072, 0.5276]],
[[0.3085, 0.2845, 0.1137, 0.2933],
[0.1662, 0.0330, 0.5265, 0.2743],
[0.7712, 0.1630, 0.0458, 0.0201]]])
log_output:
tensor([[[-2.2366, -2.3449, -1.7364, -0.4761],
[-0.7186, -2.0927, -2.1302, -1.3079],
[-2.4796, -1.7069, -1.5739, -0.6394]],
[[-1.1759, -1.2569, -2.1746, -1.2267],
[-1.7947, -3.4105, -0.6415, -1.2935],
[-0.2599, -1.8142, -3.0836, -3.9080]]])
logsoftmax_output:
tensor([[[-2.2366, -2.3449, -1.7364, -0.4761],
[-0.7186, -2.0927, -2.1302, -1.3079],
[-2.4796, -1.7069, -1.5739, -0.6394]],
[[-1.1759, -1.2569, -2.1746, -1.2267],
[-1.7947, -3.4105, -0.6415, -1.2935],
[-0.2599, -1.8142, -3.0836, -3.9080]]])
nlloss_output:
tensor(2.5749)
crossentropyloss_output:
tensor(2.5749)
# 2.5749 = (2.3449+2.1302+2.4796+1.1759+3.4105+3.9080)/6
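For multi-dimensional input, nn.CrossEntropyLoss and nn.NLLLoss expect the class dimension at index 1, which is why the permute to (B, C, N) is needed above. As an alternative sketch (an addition, assuming x_input and y_target from the script above are still in scope), the batch and object dimensions can be flattened so the loss sees an ordinary (N, C) problem; with the default reduction='mean' this gives the same 2.5749.
# alternative without permute: flatten (B, N, C) -> (B*N, C) and (B, N) -> (B*N,)
flat_loss = nn.CrossEntropyLoss()(x_input.reshape(-1, x_input.shape[-1]), y_target.reshape(-1))
print('flat_loss:\n', flat_loss) # expected to equal crossentropyloss_output above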