用PyTorch编写自定义的损失函数(Custom Loss Function)

1. 什么是损失函数(Loss Function)

损失函数在训练AI模型的过程中,用于计算模型预测值与真实值之间的误差(Error),又称误差函数(Error function)

损失函数(Loss Function)

2. PyTorch中内建的损失函数

在torch.nn中内建了很多常用的损失函数,依据用途,可以分为三类:

  • 用于回归问题(Regression loss):回归损失主要关注连续值,例如: L1范数损失(L1Loss), 均方误差损失(MSELoss)等。
  • 用于分类问题(Classification loss):分类损失函数处理离散值,例如,交叉熵损失(CrossEntropyLoss), 二进制交叉熵损失(BCELoss)等。
  • 用于排名问题(Ranking loss): 排名损失用于预测输入样本之间的相对距离,例如,MarginRankingLoss,TripletMarginLoss等。

3. 使用PyTorch编写自定义损失函数

3.1 何时自己编写损失函数

当任务不再是单一的回归或者分类,PyTorch的内建损失函数无法满足任务需求时,需要开发者自己手动编写损失函数,例如:YOLOv5的total_loss为

total_loss = class loss + bbox loss+ object loss

就需要手动编写。

3.2 用PyTorch编写损失函数模板

方案1 -- 编写一个继承nn.Module的Loss函数类,并在forward()方法中实现loss计算:

class MyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))

方案2 -- 编写一个Loss函数类,并在_ call _方法中实现loss计算:

class MyLoss():
    def __init__(self):
        
    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))

以上两种Loss函数类的使用方式都是:

loss_fn = MyLoss() # 实例化Loss函数
loss = loss_fn(preds, targets) # 计算loss值,即预测值与真实值之间的误差
...
loss.backward() #

在实际开发工作中,以上两种实现方案都非常常见,运行效率几乎一致,无性能优劣之分;由于loss函数并没有需要学习的参数,所以使用哪一种主要看开发者习惯的编程风格(coding style)。

参考资料:

  1. [Solved] What is the correct way to implement custom loss function?
  2. A-Collection-of-important-tasks-in-pytorch

你可能感兴趣的:(用PyTorch编写自定义的损失函数(Custom Loss Function))