pytorch笔记:11) 多标签多分类中损失函数选择及样本不均衡问题

问题来源:
解决多标签多分类中损失函数选择及样本不均衡问题的2个帖子
https://cloud.tencent.com/developer/ask/226097
https://discuss.pytorch.org/t/multi-label-multi-class-class-imbalance/37573

主要弄明白nn.BCEWithLogitsLossnn.MultiLabelSoftMarginLoss有啥区别,下面用一个栗子来测试下,顺便测试了上面提及的自定义损失函数

from torch import nn  
import torch  
 
#重新封装的多标签损失函数
class WeightedMultilabel(nn.Module):  
    def __init__(self, weights: torch.Tensor):  
        super(WeightedMultilabel, self).__init__()  
        self.cerition = nn.BCEWithLogitsLoss(reduction='none')  
        self.weights = weights  
  
    def forward(self, outputs, targets):  
        loss = self.cerition(outputs, targets)  
        return (loss * self.weights).mean()  
  
x=torch.randn(3,4)  
y=torch.randn(3,4)  
#损失函数对应类别的权重
w=torch.tensor([10,2,15,20],dtype=torch.float)  
#测试不同的损失函数
criterion_BCE=nn.BCEWithLogitsLoss(w)  
criterion_mult=WeightedMultilabel(w)  
criterion_mult2=nn.MultiLabelSoftMarginLoss(w)  
  
loss1=criterion_BCE(x,y)  
loss2=criterion_mult(x,y)  
loss3=criterion_mult2(x,y)  
  
print(loss1)  
print(loss2)  
print(loss3)  
  
# tensor(7.8804)  
# tensor(7.8804)  
# tensor(7.8804)

结论:从上面的结果可以看到,3个损失函数其实是等价的- -

你可能感兴趣的:(机器·深度学习)