最近在用交叉熵损失函数,但是却频频出现bug,这里把坑都记录一下,避免以后又再一次掉进去,也希望能帮助到掉进去的人出来。
# 正确示例
loss = torch.nn.CrossEntropyLoss()
loss = loss(predict, target.long())
# 或者
loss = torch.nn.CrossEntropyLoss()(predict, target.long())
# 错误示例
loss = torch.nn.CrossEntropyLoss(predict, target.long())
如果错用了上述“错误示例”,就会报错"RuntimeError: Boolean value of Tensor with more than one value is ambiguous"
其次,要说的是传给对象的参数,第一个predict是网络的直接输出,是含有正数、负数的一些乱七八糟的数,如果是3D的,predict.shape应该是[B, Classes, D, H, W],如果是2D的,predict.shape应该是[B, Classes, H, W],其中B是batch_size,Classes是类别数,是几分类就是几(考虑背景)。第二个参数target是标签,是索引标签,3D中的大小是[B, D, H, W],2D中的大小是[B, H, W],可能你会问了,predict和target的大小不一样,怎么计算loss呢?,因为torch.nn.CrossEntropyLoss()内部封装了将target转换成和predict大小一样的onehot编码的代码,所以只需要传进去索引编码就可以,不用自己转换成onehot编码
代码示例
predict = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0]) # onehot编码格式
label2 = torch.LongTensor([1, 0, 0]) # 索引标签
loss = torch.nn.CrossEntropyLoss()(predict, label2)
同时需要注意,传进去的target应该是LongTensor类型的,如果不是,需要强制转换一下,否则就会报错,应该是这个错误“”RuntimeError: Expected object of type torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #2 ‘target’“”
2. torch.nn.BCELoss()
代码示例
predict = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0])
loss = torch.nn.BCELoss()(torch.nn.Sigmoid(predict), labels)
由上述代码可以看出,torch.nn.BCELoss()没有封装将索引编码转换成onehot编码的代码,需要自己现将索引编码转换成onehot编码,然后在上传参数。并且,predict参数也不再是网络的直接输出,而是经过sigmoid之后的,某个维度上所有值之和为1。
此外,将索引编码转换成onehot编码,可以使用**torch.nn.functional.onehot()**函数。
3. torch.nn.BCEWithLogitsLoss()
代码示例
predict = torch.randn(3, 2)
labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0])
loss = torch.nn.BCEWithLogitsLoss()(predict, labels)
再看torch.nn.BCEWithLogitsLoss()这个类,它需要上传的参数是网络的直接输出和onehot编码格式的target。
4. torch.nn.functional.binary_cross_entropy()
需要注意的是,上述前三个都是类,传参数时需要先创建对象,然后将参数传给对象,但是torch.nn.functional里面的是函数,直接传参数即可。要注意这一点,避免不必要的bug。
代码示例
predict = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = F.binary_cross_entropy(F.sigmoid(predict), target)
这是官方文档的一个例子,首先,该函数适用于二分类的交叉熵损失函数计算,其次对于predict参数,是需要经过sigmoid之后的,target是onehot编码格式的,因为target和predict的大小是一样的。如果是3D,那么predict和target的shape都应该是[B, Classes, D, H, W],如果是2D,那么它们的大小都应该是[B, Classes, H, W],其中各个字母的含义和上述相同。
注:不知道该怎么用时,查官方文档真的很有用,很有用,很有用!