目录
1. 默认的loss函数
2. 添加损失
3. 多个损失函数
mmseg 中的 loss 函数实现位于 mmseg/models/losses/ 目录下，并在 mmseg/models/losses/__init__.py 中统一导出。
在 configs 下的模型配置文件（如 configs/_base_/models/ 中的文件）里，可以通过修改 loss_decode 字段把 loss 更换为自己想要的损失函数（示例见文末）。
要添加新的损失函数，用户需要在 mmseg/models/losses/my_loss.py 中实现它。装饰器 weighted_loss 使损失可以按元素加权。
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
    """Return the element-wise absolute error between ``pred`` and ``target``.

    This function returns the raw, unreduced error only. Callers invoke it
    with extra ``weight`` / ``reduction`` / ``avg_factor`` arguments, which
    are handled by the ``weighted_loss`` decorator rather than by this body.

    Args:
        pred (torch.Tensor): Predictions; must have the same size as ``target``.
        target (torch.Tensor): Ground truth; must contain at least one element.

    Returns:
        torch.Tensor: ``|pred - target|``, same size as the inputs.
    """
    # Shapes must match exactly, and an empty target is rejected outright.
    assert pred.size() == target.size() and target.numel() > 0
    diff = pred - target
    return diff.abs()
@LOSSES.register_module()
class MyLoss(nn.Module):
    """Loss module that wraps :func:`my_loss` and scales it by ``loss_weight``.

    Registered in the ``LOSSES`` registry so it can be referenced from a
    config as ``dict(type='MyLoss', ...)``.

    Args:
        reduction (str): Default reduction passed to ``my_loss``. Expected to
            be one of 'none', 'mean', 'sum' (the values ``forward`` accepts
            for ``reduction_override``); not validated here. Defaults to 'mean'.
        loss_weight (float): Scalar multiplier applied to the loss returned by
            ``my_loss``. Defaults to 1.0.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted loss.

        Args:
            pred (torch.Tensor): Predictions, same size as ``target``.
            target (torch.Tensor): Ground truth, non-empty.
            weight (torch.Tensor, optional): Forwarded to ``my_loss``;
                presumably per-element weights applied by ``weighted_loss``.
            avg_factor (float, optional): Forwarded to ``my_loss``;
                presumably a custom denominator for 'mean' reduction.
            reduction_override (str, optional): If given, one of
                'none', 'mean', 'sum'; overrides ``self.reduction`` for this
                call only.

        Returns:
            torch.Tensor: ``loss_weight`` times the output of ``my_loss``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override takes precedence over the configured default.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
然后，用户需要将其添加到 mmseg/models/losses/__init__.py 中：
from .my_loss import MyLoss, my_loss
# Multiple losses: the decode head is supervised by cross-entropy
# (loss_weight 1.0) plus Dice (loss_weight 3.0). class_weight gives class 0
# a weight of 0.001 and class 1 a weight of 1 -- presumably to down-weight
# the background class; confirm against the dataset's class order.
loss_decode=[
    dict(
        type='CrossEntropyLoss',
        loss_name='loss_ce',
        loss_weight=1.0,
        class_weight=[0.001, 1]),
    dict(
        type='DiceLoss',
        loss_name='loss_dice',
        loss_weight=3.0,
        class_weight=[0.001, 1]),
]