maskrcnn-benchmark 代码详解之 backbone.py

前言

在backbone.py文件中,定义了各种不同的backbone结构,并使用Registry装饰器类来实现构造这些backbone结构函数的调用。指的一提的是,backbone.py将不同的模块搭建出拥有不同功能的backbone结构,为边框预测等操作提供各自合适的特征提取网络;

其中Resnet的第2个stage非常重要,因为第一个阶段就是对原始图像的一次粗糙的特征提取,从第二阶段开始往后,特征图的大小缩小两倍,但是每个stage的输入输出通道扩大两倍,这是典型的用通道换面积。详细代码解释如下:

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import OrderedDict

from torch import nn

from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform
from . import fpn as fpn_module
from . import resnet


# 定义不同的backbone并将其保存在registry模型字典中,保存的名称与resnet.py中的网络结构相同,这样就可以构造指定的网络结构
# 例如:R-50-C4意思是Resnet50前4个stage的结构,而在这个名称和resnet.py中的指定网络结构是一致的,这个R-50-C4连接这两部分

# todo 单纯用Resnet构造的backbone
@registry.BACKBONES.register("R-50-C4")
@registry.BACKBONES.register("R-50-C5")
@registry.BACKBONES.register("R-101-C4")
@registry.BACKBONES.register("R-101-C5")
def build_resnet_backbone(cfg):
    # 获得Resnet的主干网络
    body = resnet.ResNet(cfg)
    # 初始化网络结构
    model = nn.Sequential(OrderedDict([("body", body)]))
    # 指定输出层的输出通道数
    model.out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS
    return model


# todo 使用Resnet和FPN相结合的backbone
@registry.BACKBONES.register("R-50-FPN")
@registry.BACKBONES.register("R-101-FPN")
@registry.BACKBONES.register("R-152-FPN")
def build_resnet_fpn_backbone(cfg):
    # 获得Resnet的主干网络
    body = resnet.ResNet(cfg)
    # 指定输入层的输入通道数,为Resnet第2阶段的输出通道,第2个stage非常重要,
    # 因为第一个阶段就是对原始图像的一次粗糙的特征提取,从第二阶段开始往后,特征图的大小缩小两倍,
    # 但是每个stage的输入输出通道扩大两倍,这是典型的用通道换面积
    in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    # 指定输出层的输出通道数,用户自定义的
    out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS
    # 构造FPN网络结构,in_channels_stage2以后每一层的通道都是扩大两倍
    fpn = fpn_module.FPN(
        in_channels_list=[
            in_channels_stage2,
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ],
        # FPN最后的输出层输出通道是自定义的,使得网络跑通即可
        out_channels=out_channels,
        # 指定卷积方式
        conv_block=conv_with_kaiming_uniform(
            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
        ),
        # 指定最后一层的输出是否需要再经过池化等操作,这里是最大值池化
        top_blocks=fpn_module.LastLevelMaxPool(),
    )
    # 初始化网络结构
    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
    # 指定输出层的输出通道数
    model.out_channels = out_channels
    return model


@registry.BACKBONES.register("R-50-FPN-RETINANET")
@registry.BACKBONES.register("R-101-FPN-RETINANET")
def build_resnet_fpn_p3p7_backbone(cfg):
    # 获得Resnet的主干网络
    body = resnet.ResNet(cfg)
    # 指定输入层的输入通道数,为Resnet第2阶段的输出通道,第2个stage非常重要,
    # 因为第一个阶段就是对原始图像的一次粗糙的特征提取,从第二阶段开始往后,特征图的大小缩小两倍,
    # 但是每个stage的输入输出通道扩大两倍,这是典型的用通道换面积
    in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    # 指定输出层的输出通道数,用户自定义的
    out_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS
    # 获得p6p7层,也就是在FPN的p5层上再加两层
    in_channels_p6p7 = in_channels_stage2 * 8 if cfg.MODEL.RETINANET.USE_C5 \
        else out_channels
    fpn = fpn_module.FPN(
        in_channels_list=[
            0,
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ],
        out_channels=out_channels,
        conv_block=conv_with_kaiming_uniform(
            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
        ),
        # 指定最后一层的输出是否需要再经过池化等操作,RETINANET不需要池化,他是在FPN上加了p6p7层,是1阶段的目标检测模型
        top_blocks=fpn_module.LastLevelP6P7(in_channels_p6p7, out_channels),
    )
    # 初始化网络结构
    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
    # 指定输出层的输出通道数
    model.out_channels = out_channels
    return model


# todo 构造主干网络结构
def build_backbone(cfg):
    assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
        "cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
            cfg.MODEL.BACKBONE.CONV_BODY
        )
    # 通过调用装饰其类来构造本文件中定义的各种backbone结构
    return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg)

 

你可能感兴趣的:(maskrcnn,benchmark,maskrcnn,benchmark,代码详解,目标检测,Resnet,backbone.py)