Exploring PyTorch loss functions

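[Figure 1: summary of PyTorch loss functions]
The figure above is the summary. Now let's see how these loss functions are used in actual code.
First, the usage of nn.CrossEntropyLoss():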

import numpy as np
import torch

A = range(12)
A = np.asarray(A)
A = A.reshape(4,3)
A = torch.from_numpy(A).float()
A.shape

The shape of A is torch.Size([4, 3])

label = np.ones((4,1))
label = torch.from_numpy(label).long()
label.shape

The shape of label is torch.Size([4, 1])

import torch.nn as nn
loss = nn.CrossEntropyLoss()
loss(A,label)

This raises the following error:
RuntimeError: 1D target tensor expected, multi-target not supported

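Roughly, this means the target must be 1D. What I passed in is a 4×1 matrix, and although it has only one column it is still a 2D tensor. The target has to be a 1D tensor of class indices with shape (n,), so I use flatten() to turn the matrix into a 1D vector.
The complete code after the fix: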

import numpy as np
import torch
import torch.nn as nn

A = range(12)
A = np.asarray(A)
A = A.reshape(4,3)
A = torch.from_numpy(A).float()
print(A.shape)
label = np.ones((4,1)).flatten()  # the changed line: shape (4, 1) -> (4,)
label = torch.from_numpy(label).long()
print(label.shape)
loss = nn.CrossEntropyLoss()
loss(A,label)

Output:
torch.Size([4, 3])
torch.Size([4])
Result: tensor(1.4076)
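
As a sanity check on the result above, here is a minimal sketch that recomputes the same loss by hand, assuming the default reduction="mean". The names logits, target, log_probs and manual are my own and do not appear in the code above; squeeze(1) is used here as an alternative to flatten() for dropping the extra dimension.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same data as above: 4 samples, 3 classes, every sample labelled class 1.
logits = torch.arange(12, dtype=torch.float32).reshape(4, 3)
target = torch.ones(4, 1, dtype=torch.long).squeeze(1)  # (4, 1) -> (4,)

loss = nn.CrossEntropyLoss()(logits, target)

# CrossEntropyLoss is log_softmax followed by the negative log-likelihood,
# averaged over the batch when reduction="mean" (the default).
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), target].mean()

print(target.shape)
print(loss, manual)
```

Both values should agree with the tensor(1.4076) reported above, since every row of A has the form [x, x+1, x+2] and therefore the same softmax.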

Next, nn.BCEWithLogitsLoss():

B = np.random.randn(4,1)
B = torch.from_numpy(B).float()
B.shape

torch.Size([4, 1])

The target is the same label as before:

label = np.ones((4,1)).flatten()
label = torch.from_numpy(label).long()
label.shape

torch.Size([4])

import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)

But this raises an error:
ValueError: Target size (torch.Size([4])) must be the same as input size (torch.Size([4, 1]))
Huh? Now it wants the target to have the same shape as the input B.
Fine, let's take the flatten() back out.

B = np.random.randn(4,1)
B = torch.from_numpy(B).float()
print(B.shape)
label = np.ones((4,1))
label = torch.from_numpy(label).long()
print(label.shape)
loss(B,label)

That still fails:
RuntimeError: result type Float can't be cast to the desired output type Long

BCEWithLogitsLoss expects a floating-point target, so drop the long() call as well. The final working code:

B = np.random.randn(4,1)
B = torch.from_numpy(B).float()
print(B.shape)
label = np.ones((4,1))
label = torch.from_numpy(label)
print(label.shape)
import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)

Output:
torch.Size([4, 1])
torch.Size([4, 1])
Result: tensor(0.7415, dtype=torch.float64) (the exact value changes from run to run because B is random)
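
The float64 in that result comes from the label, which keeps NumPy's default float64 dtype once long() is removed. The sketch below is one way to keep everything in float32, and it also checks that BCEWithLogitsLoss agrees with an explicit sigmoid followed by nn.BCELoss. The fixed seed and the names logits, target and reference are my own assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed, only so the run is repeatable

logits = torch.randn(4, 1)  # raw scores, float32
target = torch.ones(4, 1)   # float32 target with the same shape as logits

loss = nn.BCEWithLogitsLoss()(logits, target)

# BCEWithLogitsLoss applies the sigmoid internally, so it should match
# an explicit sigmoid followed by BCELoss up to floating-point error.
reference = nn.BCELoss()(torch.sigmoid(logits), target)

print(loss.dtype)
print(loss, reference)
```

When the label starts out as a NumPy array, calling .float() on the converted tensor (instead of .long()) should have the same effect.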

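I then tried passing both B and label as 1D tensors, and that also produces a result. In other words, all that matters is that the two have the same shape.
The following version works too: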

B = np.random.randn(4,1).flatten()
B = torch.from_numpy(B).float()
print(B.shape)
label = np.ones((4,1))
label = torch.from_numpy(label).flatten()
print(label.shape)
import torch.nn as nn
loss = nn.BCEWithLogitsLoss()
loss(B,label)
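
To make the shape rule concrete, here is a small sketch that feeds the same values in both layouts and then deliberately mixes them. It assumes the default reduction="mean", which averages over all elements, so the (4, 1) and (4,) layouts should give the same number. The seed and the variable names are my own choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed for repeatability
criterion = nn.BCEWithLogitsLoss()

logits_2d = torch.randn(4, 1)
target_2d = torch.ones(4, 1)

# The same values viewed as 1D tensors of shape (4,).
logits_1d = logits_2d.flatten()
target_1d = target_2d.flatten()

loss_2d = criterion(logits_2d, target_2d)
loss_1d = criterion(logits_1d, target_1d)

# Mixing the two layouts is rejected: the shapes must match exactly.
try:
    criterion(logits_2d, target_1d)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

print(loss_2d, loss_1d, mismatch_rejected)
```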
