mmsegment定制模型(八)

目录

1、定制优化器

1.1、优化器注册类

1.2、优化器定义

2、开发新组件

2.1、添加新的主干(backbone)

2.2、添加新的头(head)

2.3、增加新的损失


1、定制优化器

1.1、优化器注册类

mmsegmentation中未对优化器进行封装,真实的封装及注册类在mmcv中完成。我们可以在schedule默认配置中看到优化器的字典型配置。

在mmseg/api/train.py中,由build_optimizer完成构造。

进一步代码跟进,在mmcv中实现该注册类。

mmsegment定制模型(八)_第1张图片

所以,不用纠结,工程代码中为什么字典型优化器配置,如何转为pytorch的优化器类。

1.2、优化器定义

假设定义一个名为优化MyOptimizer,其中有参数a,b和c。需要首先在文件中实现新的优化器,例如,在mmseg/core/optimizer/my_optimizer.py:

from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer

@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):
    def __init__(self, a, b, c):
        pass

然后在其中添加此模块,mmseg/core/optimizer/__init__.py以便注册表将找到新模块并将其添加:

from .my_optimizer import MyOptimizer

然后,您可以MyOptimizer在optimizer配置文件的字段中使用。在配置中,优化器由字段定义,optimizer如下所示:

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

要使用自己的优化器,可以将字段更改为:

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)

我们已经支持使用由PyTorch实现的所有优化器,唯一的修改就是更改optimizer配置文件的字段。例如,如果您要使用ADAM,尽管性能会下降很多,但可以进行如下修改。

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)

用户可以按照PyTorch的API文档直接设置参数。

2、开发新组件

MMSegmentation中主要有2种类型的组件。

1)backbone:通常是卷积网络堆栈以提取特征图,例如ResNet,HRNet。

2)head:用于语义分割图解码的组件。

2.1、添加新的主干(backbone)

在这里,我们以MobileNet为例说明如何开发新组件。完整实现方式可以参考mmseg/models/backbone/mobilenet_v2.py。

1)创建一个新文件mmseg/models/backbones/mobilenet.py。以下三个类函数是必须实现的。

import torch.nn as nn
from ..registry import BACKBONES

@BACKBONES.register_module
class MobileNet(nn.Module):
    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass

2)在中导入模块mmseg/models/backbones/__init__.py。

from .mobilenet import MobileNet

3)在配置文件中使用它。

model = dict( ... 
    backbone=dict( type='MobileNet', arg1=xxx, arg2=xxx), ...

2.2、添加新的头(head)

在MMSegmentation中,为所有head提供了基类BaseDecodeHead。所有新实现的head都应从中派生。在这里,我们以PSPNet为例说明如何开发新的头,如下所示。

首先,在中添加一个新的head:mmseg/models/decode_heads/psp_head.py。PSPNet实现了用于解码的分割head。要实现head,基本上我们需要实现以下新模块的三个功能。完整参考:psp_head.py中PSPHead中完整实现。

@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)

    def init_weights(self):
        pass

    def forward(self, inputs):
        pass

接下来,用户需要在其中添加模块,mmseg/models/decode_heads/__init__.py因此相应的注册表可以找到并加载它们。

如下配置PSPNet文件:

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', 
            use_sigmoid=False, loss_weight=1.0)
            ))

2.3、增加新的损失

假设要添加新的损失MyLoss。要添加新的损失函数,用户需要在mmseg/models/losses/my_loss.py中实现它。装饰weighted_loss器使每个元素的损失得以加权。

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

然后,用户需要将其添加到中mmseg/models/losses/__init__.py。

from .my_loss import MyLoss, my_loss

要使用它,修改loss_xxx字段。然后,您需要修改head的loss_decode字段。 loss_weight可以用来平衡多重损失。

loss_decode=dict(type='MyLoss', loss_weight=1.0))

 

你可能感兴趣的:(#,mmsegment,python,pytorch,深度学习)