The MMDetection version discussed here is 2.23.0. Take the config file faster_rcnn_r50_fpn.py as an example:
neck=dict(
type='FPN',
    in_channels=[256, 512, 1024, 2048],  # typically the 4 feature maps output by ResNet-50; their resolutions are 1/4, 1/8, 1/16 and 1/32 of the input image, with 256, 512, 1024 and 2048 channels respectively
out_channels=256,
num_outs=5),
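As a quick sanity check, the FPN neck can be instantiated directly with these values. Below is a minimal sketch (assuming MMDetection 2.23.0 is installed; the input shapes mimic ResNet-50 features for an 800x800 image):

import torch
from mmdet.models.necks import FPN

in_channels = [256, 512, 1024, 2048]  # C2-C5 channels of ResNet-50
scales = [200, 100, 50, 25]           # 1/4, 1/8, 1/16, 1/32 of 800
inputs = [torch.rand(1, c, s, s) for c, s in zip(in_channels, scales)]

neck = FPN(in_channels=in_channels, out_channels=256, num_outs=5).eval()
outputs = neck(inputs)
for out in outputs:
    print(out.shape)  # 5 maps, all with 256 channels; the 13x13 fifth level comes from max-pooling the 25x25 one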
The overall model structure is shown in the figure below.

(Figure omitted: overall model structure.)

The following code builds the self.lateral_convs and self.fpn_convs parts of the FPN module shown in that figure:
self.lateral_convs = nn.ModuleList()  # holds the 1x1 lateral convs
self.fpn_convs = nn.ModuleList()      # holds the 3x3 output convs
for i in range(self.start_level, self.backbone_end_level):
    l_conv = ConvModule(
        in_channels[i],
        out_channels,
        1,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
        act_cfg=act_cfg,
        inplace=False)
    fpn_conv = ConvModule(
        out_channels,
        out_channels,
        3,
        padding=1,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        act_cfg=act_cfg,
        inplace=False)
    self.lateral_convs.append(l_conv)
    self.fpn_convs.append(fpn_conv)
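ConvModule here is mmcv's convenience wrapper bundling a conv layer with optional normalization and activation. With the faster_rcnn_r50_fpn config above (add_extra_convs=False), the loop creates exactly one 1x1 and one 3x3 conv per backbone level, which a small check confirms (a sketch under the same assumptions as before):

neck = FPN(in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5)
print(len(neck.lateral_convs))  # 4: one 1x1 conv per backbone level
print(len(neck.fpn_convs))      # 4: no extra convs; the 5th level is max-pooled in forward()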
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from ..builder import NECKS
@NECKS.register_module()
class FPN(BaseModule):
r"""Feature Pyramid Network.
This is an implementation of paper `Feature Pyramid Networks for Object
Detection `_.
Args:
in_channels (list[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale).
num_outs (int): Number of output scales.
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
add_extra_convs (bool | str): If bool, it decides whether to add conv
layers on top of the original feature maps. Default to False.
If True, it is equivalent to `add_extra_convs='on_input'`.
If str, it specifies the source feature map of the extra convs.
Only the following options are allowed
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
- 'on_lateral': Last feature map after lateral convs.
- 'on_output': The last output feature map after fpn convs.
relu_before_extra_convs (bool): Whether to apply relu before the extra
conv. Default: False.
no_norm_on_lateral (bool): Whether to apply norm on lateral.
Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: dict(mode='nearest').
init_cfg (dict or list[dict], optional): Initialization config dict.
Example:
>>> import torch
>>> in_channels = [2, 3, 5, 7]
>>> scales = [340, 170, 84, 43]
>>> inputs = [torch.rand(1, c, s, s)
... for c, s in zip(in_channels, scales)]
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
>>> outputs = self.forward(inputs)
>>> for i in range(len(outputs)):
... print(f'outputs[{i}].shape = {outputs[i].shape}')
outputs[0].shape = torch.Size([1, 11, 340, 340])
outputs[1].shape = torch.Size([1, 11, 170, 170])
outputs[2].shape = torch.Size([1, 11, 84, 84])
outputs[3].shape = torch.Size([1, 11, 43, 43])
"""
    def __init__(self,
                 in_channels,  # input channels per scale, i.e. the output channels of the backbone
                 out_channels,  # FPN output channels; a single value shared by every scale
                 num_outs,  # number of output stages (extra levels can be added, so num_outs need not equal len(in_channels))
                 start_level=0,  # index of the first backbone stage to use, default 0
                 end_level=-1,  # index of the last backbone stage to use; default -1 means use up to and including the last level
                 add_extra_convs=False,  # bool or str
                 # If True, it is equivalent to add_extra_convs='on_input'.
                 # If str, it specifies the source feature map of the extra convs:
                 #   on_input:   the highest backbone feature map feeds the extra convs
                 #   on_lateral: the highest lateral result feeds the extra convs
                 #   on_output:  the highest FPN output feeds the extra convs
                 relu_before_extra_convs=False,  # whether to apply ReLU before the extra convs (default: False)
                 no_norm_on_lateral=False,  # whether to skip normalization on the 1x1 lateral convs (default: False)
                 conv_cfg=None,  # config dict for building conv layers, default None
                 norm_cfg=None,  # config dict for building norm layers, default None
                 act_cfg=None,  # config dict for building activation layers, default None
                 upsample_cfg=dict(mode='nearest'),  # config dict for upsampling, default 'nearest'
                 init_cfg=dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
super(FPN, self).__init__(init_cfg)
        assert isinstance(in_channels, list)  # the backbone outputs must be given as a list
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False  # whether fp16 is enabled
self.upsample_cfg = upsample_cfg.copy()
        if end_level == -1:  # -1 means the last backbone feature map is the final input level
            self.backbone_end_level = self.num_ins  # like end_level, backbone_end_level is the end-stage index into the backbone
            assert num_outs >= self.num_ins - start_level  # extra convs may add levels, so num_outs can exceed num_ins - start_level
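            # e.g. the faster_rcnn_r50_fpn config above: num_outs=5 >= 4 - 0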
else:
# if end_level < inputs, no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
assert isinstance(add_extra_convs, (str, bool))
if isinstance(add_extra_convs, str):
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
elif add_extra_convs: # True
self.add_extra_convs = 'on_input'
        self.lateral_convs = nn.ModuleList()  # holds the 1x1 lateral convs
        self.fpn_convs = nn.ModuleList()  # holds the 3x3 output convs
for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False)
fpn_conv = ConvModule(
out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
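        # e.g. a RetinaNet-style setup (num_outs=5, start_level=1, four backbone
        # stages -> backbone_end_level=4) gives extra_levels = 5 - 4 + 1 = 2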
if self.add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0 and self.add_extra_convs == 'on_input':
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
extra_fpn_conv = ConvModule(
in_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.fpn_convs.append(extra_fpn_conv)
@auto_fp16()
def forward(self, inputs):
"""Forward function."""
assert len(inputs) == len(self.in_channels)
# build laterals
        # each used backbone stage output goes through its 1x1 lateral conv
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
# fix runtime error of "+=" inplace operation in PyTorch 1.10
laterals[i - 1] = laterals[i - 1] + F.interpolate(
laterals[i], **self.upsample_cfg)
else:
                # if 'scale_factor' is not given, upsample the higher-level map to the spatial size of the level below it
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] = laterals[i - 1] + F.interpolate(
laterals[i], size=prev_shape, **self.upsample_cfg)
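        # after this loop each laterals[i] has accumulated (upsampled) features
        # from every level above it, i.e. the top-down pathway of the FPN paper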
# build outputs
# part 1: from original levels
        # feed each lateral result (after the 1x1 conv) through a 3x3 conv to get the output
outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
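                    # kernel size 1 with stride 2 simply keeps every other pixel,
                    # halving the resolution without extra parameters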
# add conv layers on top of original feature maps (RetinaNet)
else:
if self.add_extra_convs == 'on_input':
extra_source = inputs[self.backbone_end_level - 1]
elif self.add_extra_convs == 'on_lateral':
extra_source = laterals[-1]
elif self.add_extra_convs == 'on_output':
extra_source = outs[-1]
else:
raise NotImplementedError
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
else:
outs.append(self.fpn_convs[i](outs[-1]))
return tuple(outs)
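To exercise the extra-conv branch as well, here is a sketch with a RetinaNet-style configuration (start_level=1, add_extra_convs='on_input'; shapes again assume an 800x800 input through ResNet-50):

import torch
from mmdet.models.necks import FPN

inputs = [torch.rand(1, c, s, s)
          for c, s in zip([256, 512, 1024, 2048], [200, 100, 50, 25])]
neck = FPN(
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5,
    start_level=1,                  # skip C2
    add_extra_convs='on_input').eval()
outputs = neck(inputs)
print([tuple(o.shape[2:]) for o in outputs])
# [(100, 100), (50, 50), (25, 25), (13, 13), (7, 7)]
# the last two levels come from the stride-2 extra convs applied to C5, not from max pooling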