FPN
import torch.nn as nn
import torch.nn.functional as F
from ..module.conv import ConvModule
from ..module.init_weights import xavier_init
class FPN(nn.Module):
    """Top-down feature pyramid built from 1x1 lateral convolutions.

    Every selected backbone level is projected to ``out_channels`` by a 1x1
    ``ConvModule``. Starting from the coarsest level, each lateral map is
    bilinearly resized to the next finer level and added to it.

    Args:
        in_channels (list[int]): Channel count of each backbone level.
        out_channels (int): Channel count of every pyramid level.
        num_outs (int): Expected number of output levels. It is only
            validated against the level range; no extra levels are built.
        start_level (int): Index of the first backbone level that is used.
            Default: 0.
        end_level (int): Exclusive index of the last backbone level that is
            used. Default: -1, meaning "through the last level".
        conv_cfg: Forwarded to every lateral ``ConvModule``. Default: None.
        norm_cfg: Forwarded to every lateral ``ConvModule``. Default: None.
        activation: Forwarded to every lateral ``ConvModule``. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 activation=None
                 ):
        super(FPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.fp16_enabled = False

        if end_level == -1:
            # Use every backbone level from start_level onwards.
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # An explicit end level must produce exactly num_outs levels.
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        # One 1x1 projection per used backbone level.
        self.lateral_convs = nn.ModuleList()
        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                activation=activation,
                inplace=False)
            self.lateral_convs.append(l_conv)
        self.init_weights()

    def init_weights(self):
        """Apply uniform Xavier initialisation to every ``nn.Conv2d``."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, inputs):
        """Build the pyramid from the backbone feature maps.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``, ordered from finest to coarsest.

        Returns:
            tuple[Tensor]: One map per lateral convolution, finest first.
        """
        assert len(inputs) == len(self.in_channels)

        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # Top-down pass, coarsest to finest.
        for i in range(len(laterals) - 1, 0, -1):
            # Resize to the actual size of the finer level instead of a
            # fixed factor of 2, so odd-sized feature maps still line up.
            prev_shape = tuple(laterals[i - 1].shape[2:])
            # Out-of-place add: an in-place "+=" can overwrite a tensor
            # that autograd saved for the lateral conv's activation.
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=prev_shape, mode='bilinear')

        return tuple(laterals)
PAN
import torch.nn as nn
import torch.nn.functional as F
from ..module.conv import ConvModule
from .fpn import FPN
class PAN(FPN):
    """Path Aggregation Network for Instance Segmentation.

    This is an implementation of the `PAN in Path Aggregation Network
    <https://arxiv.org/abs/1803.01534>`_.

    On top of the lateral convolutions and top-down pass inherited from
    ``FPN``, a bottom-up pass resizes each level to the next coarser one and
    adds it there.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale)
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        activation (str): Activation setting forwarded to ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 activation=None):
        super(PAN,
              self).__init__(in_channels, out_channels, num_outs, start_level,
                             end_level, conv_cfg, norm_cfg, activation)
        # NOTE: FPN.__init__ already calls init_weights(); this second call
        # is redundant but kept so seeded initialisation stays unchanged.
        self.init_weights()

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``, ordered from finest to coarsest.

        Returns:
            tuple[Tensor]: One map per lateral convolution, finest first.
        """
        assert len(inputs) == len(self.in_channels)

        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]
        used_backbone_levels = len(laterals)

        # Top-down pass. Resize to the real size of the finer level rather
        # than a fixed factor of 2, and add out-of-place so tensors saved by
        # autograd are never overwritten.
        for i in range(used_backbone_levels - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i],
                size=tuple(laterals[i - 1].shape[2:]),
                mode='bilinear')

        # Bottom-up pass, finest to coarsest, with the same two precautions.
        outs = list(laterals)
        for i in range(used_backbone_levels - 1):
            outs[i + 1] = outs[i + 1] + F.interpolate(
                outs[i],
                size=tuple(outs[i + 1].shape[2:]),
                mode='bilinear')

        return tuple(outs)
BiFPN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class DepthwiseConvBlock(nn.Module):
    """
    Depthwise separable convolution followed by BatchNorm and ReLU.

    A per-channel (grouped) convolution is followed by a 1x1 pointwise
    convolution that maps ``in_channels`` to ``out_channels``. Neither
    convolution has a bias. ``freeze_bn`` is accepted but not read.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, freeze_bn=False):
        super(DepthwiseConvBlock, self).__init__()
        # Layers are created in the same order as before so parameter
        # initialisation and state-dict keys are unchanged.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.ReLU()

    def forward(self, inputs):
        """Run depthwise conv, pointwise conv, batch norm and ReLU in turn."""
        spatial = self.depthwise(inputs)
        mixed = self.pointwise(spatial)
        normalized = self.bn(mixed)
        return self.act(normalized)
class ConvBlock(nn.Module):
    """
    Convolution block with Batch Normalization and ReLU activation.

    Args:
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels produced by the convolution.
        kernel_size (int): Convolution kernel size. Default: 1.
        stride (int): Convolution stride. Default: 1.
        padding (int): Zero padding on each side. Default: 0.
        dilation (int): Spacing between kernel elements. Default: 1.
        freeze_bn (bool): Accepted for interface compatibility; not used.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, freeze_bn=False):
        super(ConvBlock, self).__init__()
        # ``dilation`` used to be dropped here, so callers asking for a
        # dilated convolution silently got an ordinary one.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, dilation=dilation)
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.ReLU()

    def forward(self, inputs):
        """Apply convolution, batch normalization and ReLU in that order."""
        x = self.conv(inputs)
        x = self.bn(x)
        return self.act(x)
class BiFPNBlock(nn.Module):
    """
    Bi-directional Feature Pyramid Network

    One BiFPN layer over five pyramid levels (p3 finest ... p7 coarsest).
    A top-down pass produces intermediate ``*_td`` maps and a bottom-up pass
    produces the outputs. Inputs to every fusion node are combined with
    learned, ReLU-clamped weights normalised to sum to (almost) one.

    Args:
        feature_size (int): Channel count of every pyramid level.
        epsilon (float): Added to the weight sum to avoid division by zero.
    """

    def __init__(self, feature_size=64, epsilon=0.0001):
        super(BiFPNBlock, self).__init__()
        self.epsilon = epsilon

        self.p3_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p4_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p5_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_td = DepthwiseConvBlock(feature_size, feature_size)

        self.p4_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p5_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p7_out = DepthwiseConvBlock(feature_size, feature_size)

        # torch.Tensor(rows, cols) returns uninitialised memory, which may
        # hold NaN/inf; start every fusion weight at 1 (equal contribution).
        # w1: 2 inputs per top-down node, columns ordered p6, p5, p4, p3.
        self.w1 = nn.Parameter(torch.ones(2, 4))
        self.w1_relu = nn.ReLU()
        # w2: 3 inputs per bottom-up node, columns ordered p4, p5, p6, p7.
        self.w2 = nn.Parameter(torch.ones(3, 4))
        self.w2_relu = nn.ReLU()

    def _fusion_weights(self, raw, relu):
        """Clamp ``raw`` to be non-negative and normalise each column."""
        positive = relu(raw)
        # Out-of-place division: dividing the ReLU output in place would
        # overwrite the tensor ReLU saved for its backward pass.
        return positive / (torch.sum(positive, dim=0) + self.epsilon)

    @staticmethod
    def _resize_like(source, reference):
        """Nearest-neighbour resize of ``source`` to ``reference``'s H x W."""
        return F.interpolate(source, size=tuple(reference.shape[-2:]))

    def forward(self, inputs):
        """Fuse the five pyramid levels.

        Args:
            inputs (Sequence[Tensor]): ``(p3, p4, p5, p6, p7)``, finest first.

        Returns:
            list[Tensor]: ``[p3_out, p4_out, p5_out, p6_out, p7_out]``.
        """
        p3_x, p4_x, p5_x, p6_x, p7_x = inputs
        resize = self._resize_like

        w1 = self._fusion_weights(self.w1, self.w1_relu)
        w2 = self._fusion_weights(self.w2, self.w2_relu)

        # Top-down pass: each level is fused with the already refined
        # coarser level (``*_td``), not with the raw coarser input.
        p7_td = p7_x
        p6_td = self.p6_td(w1[0, 0] * p6_x + w1[1, 0] * resize(p7_td, p6_x))
        p5_td = self.p5_td(w1[0, 1] * p5_x + w1[1, 1] * resize(p6_td, p5_x))
        p4_td = self.p4_td(w1[0, 2] * p4_x + w1[1, 2] * resize(p5_td, p4_x))
        p3_td = self.p3_td(w1[0, 3] * p3_x + w1[1, 3] * resize(p4_td, p3_x))

        # Bottom-up pass: input, top-down map and the finer output.
        p3_out = p3_td
        p4_out = self.p4_out(w2[0, 0] * p4_x + w2[1, 0] * p4_td + w2[2, 0] * resize(p3_out, p4_x))
        p5_out = self.p5_out(w2[0, 1] * p5_x + w2[1, 1] * p5_td + w2[2, 1] * resize(p4_out, p5_x))
        p6_out = self.p6_out(w2[0, 2] * p6_x + w2[1, 2] * p6_td + w2[2, 2] * resize(p5_out, p6_x))
        p7_out = self.p7_out(w2[0, 3] * p7_x + w2[1, 3] * p7_td + w2[2, 3] * resize(p6_out, p7_x))

        return [p3_out, p4_out, p5_out, p6_out, p7_out]
class BiFPN(nn.Module):
    """Stack of ``BiFPNBlock`` layers fed by three backbone feature maps.

    The three inputs are projected to ``out_channels`` with 1x1 convolutions
    (p3, p4, p5). Two further levels are derived with stride-2 3x3
    convolutions: p6 from the last backbone map and p7 from p6.

    Args:
        in_channels (Sequence[int]): Channel counts of the three backbone
            maps; indices 0, 1 and 2 are read.
        out_channels (int): Channel count of every pyramid level. Default: 64.
        num_layers (int): Number of stacked ``BiFPNBlock`` layers. Default: 2.
        epsilon (float): Stabiliser forwarded to every ``BiFPNBlock``.
            Default: 0.0001.
    """

    def __init__(self, in_channels, out_channels=64, num_layers=2, epsilon=0.0001):
        super(BiFPN, self).__init__()
        self.p3 = nn.Conv2d(in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
        self.p4 = nn.Conv2d(in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
        self.p5 = nn.Conv2d(in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)

        # p6 is computed directly from the last backbone map (hence
        # in_channels[2]); p7 is computed from p6.
        self.p6 = nn.Conv2d(in_channels[2], out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = ConvBlock(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

        # ``epsilon`` was previously accepted but never forwarded, so every
        # block silently used its own default.
        self.bifpn = nn.Sequential(
            *[BiFPNBlock(out_channels, epsilon=epsilon) for _ in range(num_layers)]
        )

    def forward(self, inputs):
        """Build the five-level pyramid and run it through the BiFPN stack.

        Args:
            inputs (Sequence[Tensor]): Exactly three backbone maps
                ``(c3, c4, c5)``.

        Returns:
            Whatever the last stacked block returns for the five levels.
        """
        c3, c4, c5 = inputs

        p3_x = self.p3(c3)
        p4_x = self.p4(c4)
        p5_x = self.p5(c5)
        p6_x = self.p6(c5)
        p7_x = self.p7(p6_x)

        return self.bifpn([p3_x, p4_x, p5_x, p6_x, p7_x])