The FPN consists of a bottom-up pathway, a top-down pathway, and lateral connections. The zoomed-in region in the figure is a lateral connection: the 1×1 convolution there serves to reduce the number of kernels, and thus the number of feature maps, without changing the spatial size of each feature map.
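To see the lateral-connection idea in isolation, here is a minimal sketch (the channel counts 512→256 and the 20×20 size are made up for illustration): a 1×1 convolution changes the channel count but leaves the resolution untouched.

import torch
import torch.nn as nn

# A 1x1 conv changes the number of feature maps (channels)
# but not the spatial dimensions.
lateral = nn.Conv2d(512, 256, kernel_size=1)   # channels: 512 -> 256
c = torch.randn(1, 512, 20, 20)                # hypothetical backbone feature map
print(lateral(c).shape)                        # torch.Size([1, 256, 20, 20])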
PyTorch implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_bn(inp, oup, stride=1, leaky=0):
    # 3x3 conv + BN + LeakyReLU; padding=1 keeps the spatial size at stride 1
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )

def conv_bn1x1(inp, oup, stride, leaky=0):
    # 1x1 conv + BN + LeakyReLU: changes only the channel count
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )

class FPN(nn.Module):
    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0
        if out_channels <= 64:
            leaky = 0.1
        # lateral 1x1 convs: unify the channel count of the three input levels
        self.output1 = conv_bn1x1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1x1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1x1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
        # 3x3 convs that smooth the merged feature maps
        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, x):
        # x is an ordered dict of feature maps, from high to low resolution
        inputs = list(x.values())
        output1 = self.output1(inputs[0])
        output2 = self.output2(inputs[1])
        output3 = self.output3(inputs[2])
        # top-down pathway: upsample the coarser level and add it to the finer one
        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
        output2 = output2 + up3
        output2 = self.merge1(output2)
        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
        output1 = output1 + up2
        output1 = self.merge2(output1)
        out = [output1, output2, output3]
        return out
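As a quick sanity check, here is a hypothetical forward pass (the channel counts, feature names, and sizes are made up; the dict input mirrors what forward() reads via x.values()):

feats = {
    "feat1": torch.randn(1, 64, 80, 80),    # highest resolution
    "feat2": torch.randn(1, 128, 40, 40),
    "feat3": torch.randn(1, 256, 20, 20),   # lowest resolution
}
fpn = FPN(in_channels_list=[64, 128, 256], out_channels=64)
for f in fpn(feats):
    print(f.shape)
# torch.Size([1, 64, 80, 80])
# torch.Size([1, 64, 40, 40])
# torch.Size([1, 64, 20, 20])

Each output level keeps its input resolution, but all three now have the same out_channels.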
Context module (SSH), which enlarges the receptive field. Stacking 3×3 convolutions emulates larger kernels: two stacked 3×3 convs have the receptive field of a 5×5 kernel, and three have that of a 7×7, which is how the conv5x5 and conv7x7 branches below get their names.
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_bn(inp, oup, stride=1, leaky=0):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    )

def conv_bn_no_relu(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    )

class SSH(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SSH, self).__init__()
        assert out_channels % 4 == 0
        leaky = 0
        if out_channels <= 64:
            leaky = 0.1
        # 3x3 branch takes half of the output channels
        self.conv3x3 = conv_bn_no_relu(in_channels, out_channels // 2, stride=1)
        # "5x5" branch: two stacked 3x3 convs
        self.conv5x5_1 = conv_bn(in_channels, out_channels // 4, stride=1, leaky=leaky)
        self.conv5x5_2 = conv_bn_no_relu(out_channels // 4, out_channels // 4, stride=1)
        # "7x7" branch reuses conv5x5_1 and adds two more 3x3 convs; its input
        # is conv5x5_1's output, so it must take out_channels // 4 channels
        self.conv7x7_2 = conv_bn(out_channels // 4, out_channels // 4, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(out_channels // 4, out_channels // 4, stride=1)

    def forward(self, x):
        conv3x3 = self.conv3x3(x)
        conv5x5_1 = self.conv5x5_1(x)
        conv5x5 = self.conv5x5_2(conv5x5_1)
        conv7x7_2 = self.conv7x7_2(conv5x5_1)
        conv7x7 = self.conv7x7_3(conv7x7_2)
        # channels: out//2 + out//4 + out//4 = out_channels
        out = torch.cat([conv3x3, conv5x5, conv7x7], dim=1)
        out = F.relu(out)
        return out
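A hypothetical smoke test (sizes made up): SSH keeps the spatial resolution and maps in_channels to out_channels, which the assert requires to be divisible by 4.

ssh = SSH(in_channels=64, out_channels=64)
x = torch.randn(1, 64, 40, 40)
print(ssh(x).shape)   # torch.Size([1, 64, 40, 40])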