早上想花一个小时参照网上其他教程,修改模型结构,写一个手写识别数字的出来,结果卡在了这个上面,loss一直降不下来,然后我就去查看了一下CrossEntropyLoss的用法,毕竟分类问题一般都用这个。
引入一个库:
import torch
假如是一个四分类任务,batch为2(只是为了显示简单,举个例子罢了)
logists = torch.randn(2, 4, requires_grad=True)
print(logists)
其实根据这个模型预测出来就是, 第一个样本预测的类别是1, 第二个样本预测的类别是2。
这里我们假设模型足够好,都预测对了,那么其实target就是ground_truth。
target = logists.argmax(dim=-1)
通过查看官方文档 CrossEntropyLoss–PyTorch 1.10.1 document, 可以知道loss有两种算法。
target可one-hot也可以不one-hot。
定义损失函数:
crition = torch.nn.CrossEntropyLoss()
先来看个target_1d版的loss:
crition(logists, target)
再来看个target one-hot版的:
注意: 该版本在我的macbook python3.7.8, torch1.10.2的版本上没有问题, 但是在我的windows python3.7.6 torch1.9.1就出问题了!!! 因此稳妥起见还是直接用target比较好
先把target转为one
t_onehot = torch.nn.functional.one_hot(target, num_classes=4)
如何是one_hot, 要求target也是浮点类型的,所以t_onehot再调用float()转为浮点类型。
crition(logists, t_onehot.float())
最后发现两种方法其实算出来的loss都是0.5601
另外插一嘴,crossEntropyLoss也可以通过nll_loss实现(如果你去看torch.nn.crossEntropyLoss的源码就会发现官方就是使用torch.nn.functional.nll_loss实现的,只不过模型输出的logists值要先经过log_softmax
这里来个序列标注的例子。模型输出是 (batch_size, seq_len, hidden_dim)。
这里我演示用两种方法,方法一会更加简洁,但可读性不如方法二,看个人喜欢。
# 序列标注,把每个token分成一类
import torch
cel = torch.nn.CrossEntropyLoss()
# 也可以用以下的方式,结果一样
# cel = torch.nn.functional.cross_entropy
batch_size, seq_len, hidden_dim = 4, 28, 128
# output logits
x = torch.randn(batch_size, seq_len, hidden_dim)
# ground truth
gt = torch.ones(batch_size, seq_len).long()
print('method 1:')
print(cel(x.permute(0, 2, 1), gt))
print()
print('method 2:')
print(cel(x.view(-1, hidden_dim), gt.view(-1)))
可以看到两种方法一样。
图像分割例子,模型输出是 (batch_size, channel, height, width), 有多少个类别就有多少个channel, 通常医疗上的语义分割是2分类,因此输出channel为2。
这里我演示用两种方法,方法一会更加简洁,但可读性不如方法二,看个人喜欢。
# 语义分割
import torch
cel = torch.nn.CrossEntropyLoss()
# 也可以用以下的方式,结果一样
# cel = torch.nn.functional.cross_entropy
# batch_size, input_channel, ouput_channel, height, width
# input_channel: 3 代表彩色图, output_channel: 2 代表语义分割二分类
b, ic, oc, h, w = 4, 3, 2, 28, 28
# output logits, predict: x.argmax(dim=1)
x = torch.randn(b, oc, h, w)
# ground truth
gt = torch.ones(b, 1, h, w).long()
# 这里可以把 h * w 这种图像的二维数据展平成一维数据,就和序列标注一样了
print('method 1:')
print(cel(x.view(b, oc, -1), gt.squeeze().view(b, -1)))
print()
print('method 2:')
print(cel(x.permute(0, 2, 3, 1).reshape(-1, oc), gt.squeeze().reshape(-1)))