Replace fpn.py in FCOS with the code below.
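Before the file itself, a quick look at the idea: ASFF rescales P3-P5 to a common resolution and blends them with a per-pixel softmax over three learned weight maps, so the three contributions sum to 1 at every spatial location. The snippet below is a standalone sketch of just that fusion rule with random tensors; it is not part of fpn.py, and the shapes are made up for illustration.

import torch
import torch.nn.functional as F

# three same-sized 256-channel maps, standing in for P3-P5 after resizing
feats = [torch.randn(1, 256, 32, 32) for _ in range(3)]
# one scalar logit map per level, stacked along the channel dimension
weight_logits = torch.randn(1, 3, 32, 32)
weights = F.softmax(weight_logits, dim=1)  # per-pixel weights across the 3 levels
fused = sum(f * weights[:, i:i + 1] for i, f in enumerate(feats))
print(fused.shape)                                               # torch.Size([1, 256, 32, 32])
print(bool(weights.sum(dim=1).allclose(torch.ones(1, 32, 32))))  # True: weights sum to 1

The full replacement for fpn.py follows.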
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn


class ASFF(nn.Module):
    """
    Adaptively Spatial Feature Fusion over three pyramid levels.
    level 0 fuses into the coarsest map, level 2 into the finest.
    All channels are fixed to 256 to match FCOS's FPN outputs.
    """

    def __init__(self, level, rfb=False, vis=False):
        super(ASFF, self).__init__()
        self.level = level
        # self.dim = [512, 256, 256]
        self.dim = [256, 256, 256]
        self.inter_dim = self.dim[self.level]
        if level == 0:
            self.stride_level_1 = add_conv(256, self.inter_dim, 3, 2)
            self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2)
            # self.expand = add_conv(self.inter_dim, 1024, 3, 1)
            self.expand = add_conv(self.inter_dim, 256, 3, 1)
        elif level == 1:
            # self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1)
            self.compress_level_0 = add_conv(256, self.inter_dim, 1, 1)
            self.stride_level_2 = add_conv(256, self.inter_dim, 3, 2)
            # self.expand = add_conv(self.inter_dim, 512, 3, 1)
            self.expand = add_conv(self.inter_dim, 256, 3, 1)
        elif level == 2:
            # self.compress_level_0 = add_conv(512, self.inter_dim, 1, 1)
            self.compress_level_0 = add_conv(256, self.inter_dim, 1, 1)
            self.expand = add_conv(self.inter_dim, 256, 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16

        self.weight_level_0 = add_conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_1 = add_conv(self.inter_dim, compress_c, 1, 1)
        self.weight_level_2 = add_conv(self.inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(
            compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def forward(self, x_level_0, x_level_1, x_level_2):
        # Bring all three inputs to the resolution of self.level.
        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(
                x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(
                level_0_compressed, scale_factor=4, mode='nearest')
            level_1_resized = F.interpolate(
                x_level_1, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2

        # Predict one weight map per level and normalize them per pixel.
        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat(
            (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
            level_1_resized * levels_weight[:, 1:2, :, :] + \
            level_2_resized * levels_weight[:, 2:, :, :]

        out = self.expand(fused_out_reduced)

        if self.vis:
            return out, levels_weight, fused_out_reduced.sum(dim=1)
        else:
            return out


def add_conv(in_ch, out_ch, ksize, stride, leaky=True):
    """
    Add a conv2d / batchnorm / leaky ReLU block.
    Args:
        in_ch (int): number of input channels of the convolution layer.
        out_ch (int): number of output channels of the convolution layer.
        ksize (int): kernel size of the convolution layer.
        stride (int): stride of the convolution layer.
        leaky (bool): if True use LeakyReLU(0.1), otherwise ReLU6.
    Returns:
        stage (Sequential): Sequential layers composing a convolution block.
    """
    stage = nn.Sequential()
    pad = (ksize - 1) // 2
    stage.add_module('conv', nn.Conv2d(in_channels=in_ch,
                                       out_channels=out_ch, kernel_size=ksize,
                                       stride=stride, padding=pad, bias=False))
    stage.add_module('batch_norm', nn.BatchNorm2d(out_ch))
    if leaky:
        stage.add_module('leaky', nn.LeakyReLU(0.1))
    else:
        stage.add_module('relu6', nn.ReLU6(inplace=True))
    return stage


class FPN(nn.Module):
    """
    Module that adds FPN on top of a list of feature maps.
    The feature maps are currently supposed to be in increasing depth
    order, and must be consecutive.
    """

    def __init__(
        self, in_channels_list, out_channels, conv_block, top_blocks=None
    ):
        """
        Arguments:
            in_channels_list (list[int]): number of channels for each feature map
                that will be fed
            out_channels (int): number of channels of the FPN representation
            conv_block (callable): factory building the lateral/output conv layers
            top_blocks (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list
        """
        super(FPN, self).__init__()
        self.inner_blocks = []
        self.layer_blocks = []
        for idx, in_channels in enumerate(in_channels_list, 1):
            inner_block = "fpn_inner{}".format(idx)
            layer_block = "fpn_layer{}".format(idx)
            if in_channels == 0:
                continue
            inner_block_module = conv_block(in_channels, out_channels, 1)
            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
            self.add_module(inner_block, inner_block_module)
            self.add_module(layer_block, layer_block_module)
            self.inner_blocks.append(inner_block)
            self.layer_blocks.append(layer_block)
        self.top_blocks = top_blocks
        # One ASFF module per pyramid level: level 0 refines the coarsest map,
        # level 2 the finest.
        self.asff_level0 = ASFF(level=0)
        self.asff_level1 = ASFF(level=1)
        self.asff_level2 = ASFF(level=2)

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): feature maps for each feature level.
        Returns:
            results (tuple[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
        results = []
        results.append(getattr(self, self.layer_blocks[-1])(last_inner))
        for feature, inner_block, layer_block in zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        ):
            if not inner_block:
                continue
            inner_lateral = getattr(self, inner_block)(feature)
            # interpolate to the exact lateral size (rather than scale_factor=2)
            # so the top-down path still aligns when feature sizes are odd
            inner_top_down = F.interpolate(
                last_inner,
                size=(int(inner_lateral.shape[-2]), int(inner_lateral.shape[-1])),
                mode='nearest'
            )
            last_inner = inner_lateral + inner_top_down
            results.insert(0, getattr(self, layer_block)(last_inner))

        # results is ordered finest first (P3, P4, P5); ASFF takes the coarsest
        # map as its first argument, so the levels are fed in reverse order.
        results_after_asff = [None, None, None]
        results_after_asff[2] = self.asff_level0(
            results[2], results[1], results[0])
        results_after_asff[1] = self.asff_level1(
            results[2], results[1], results[0])
        results_after_asff[0] = self.asff_level2(
            results[2], results[1], results[0])

        if isinstance(self.top_blocks, LastLevelP6P7):
            last_results = self.top_blocks(x[-1], results[-1])
            results_after_asff.extend(last_results)
        elif isinstance(self.top_blocks, LastLevelMaxPool):
            last_results = self.top_blocks(results[-1])
            results_after_asff.extend(last_results)
        return tuple(results_after_asff)


class LastLevelMaxPool(nn.Module):
    def forward(self, x):
        return [F.max_pool2d(x, 1, 2, 0)]


class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    """

    def __init__(self, in_channels, out_channels):
        super(LastLevelP6P7, self).__init__()
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for module in [self.p6, self.p7]:
            nn.init.kaiming_uniform_(module.weight, a=1)
            nn.init.constant_(module.bias, 0)
        self.use_P5 = in_channels == out_channels

    def forward(self, c5, p5):
        x = p5 if self.use_P5 else c5
        p6 = self.p6(x)
        p7 = self.p7(F.relu(p6))
        return [p6, p7]
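
As a sanity check, the ASFF-augmented FPN can be exercised with dummy ResNet-50 features. This is only a minimal sketch: simple_conv_block is a hypothetical stand-in for the conv_with_kaiming_uniform factory that FCOS normally passes in as conv_block, and the channel counts and shapes assume a ResNet-50 backbone with a 512x512 input. Guarding it with if __name__ == "__main__" means it can be appended to fpn.py without side effects on import.

if __name__ == "__main__":
    # Hypothetical stand-in for FCOS's conv_block factory (same call signature).
    def simple_conv_block(in_channels, out_channels, kernel_size, stride=1):
        return nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=(kernel_size - 1) // 2)

    fpn = FPN(
        in_channels_list=[512, 1024, 2048],   # C3-C5 channels of ResNet-50
        out_channels=256,
        conv_block=simple_conv_block,
        top_blocks=LastLevelP6P7(256, 256),   # in == out, so P6/P7 are built from P5
    )
    fpn.eval()
    # C3-C5 feature maps for a 512x512 image (strides 8, 16, 32)
    c3 = torch.randn(1, 512, 64, 64)
    c4 = torch.randn(1, 1024, 32, 32)
    c5 = torch.randn(1, 2048, 16, 16)
    with torch.no_grad():
        outputs = fpn([c3, c4, c5])
    print([tuple(o.shape) for o in outputs])
    # [(1, 256, 64, 64), (1, 256, 32, 32), (1, 256, 16, 16), (1, 256, 8, 8), (1, 256, 4, 4)]

The five printed shapes correspond to P3-P7; the first three have passed through ASFF, while P6/P7 come from LastLevelP6P7 exactly as in the stock FCOS FPN.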