全局卷积网络模块(Global Convolutional Network, GCN)由清华大学与旷视科技(Megvii)的研究者提出,用于改善语义分割模型的性能。其核心思想是利用非对称(可分离)卷积获得大感受野,从而同时兼顾语义分割任务中像素的分类与定位问题,以达到更好的分割效果。相较于普通的k×k卷积,非对称卷积能够以更少的参数量获得与大卷积核相同的感受野。以3×3的普通卷积为例:先做3×1卷积、再做1×3卷积,其感受野同样为3×3,但单通道下的参数量由9个减少为6个(注意:两者只有在3×3卷积核可分解为两个向量的外积,即秩为1时才完全等价)。GCN模块由"1×k+k×1"和"k×1+1×k"两组非对称卷积分支构成,每条分支都获得k×k的感受野且不改变特征图的空间尺寸,最后将两条分支的输出逐元素相加融合,输出预测图。其中k为卷积核大小,n为输出通道数,即预测类别数(包含背景)。
# GCN block on a 24-channel input with kernel size k = 15 and 4 output channels.
# branch1_0: a (15 x 1) conv followed by a (1 x 15) conv.
# branch1_1: the same two convs in the opposite order.
# Padding 7 (= 15 // 2) along each kernel's long axis, with stride 1,
# leaves the spatial size H x W unchanged.
self.branch1_0 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(15, 1), stride=1, padding=(7, 0)),
    nn.Conv2d(4, 4, kernel_size=(1, 15), stride=1, padding=(0, 7)),
)
self.branch1_1 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(1, 15), stride=1, padding=(0, 7)),
    nn.Conv2d(4, 4, kernel_size=(15, 1), stride=1, padding=(7, 0)),
)
def forward(self, x):
    """Apply the GCN block: add the outputs of its two separable branches.

    ``self.branch1_0`` and ``self.branch1_1`` are the two asymmetric-conv
    branches (k x 1 then 1 x k, and 1 x k then k x 1). Both are run on the
    same input and summed element-wise.

    Args:
        x: input feature map passed unchanged to both branches.

    Returns:
        The element-wise sum of the two branch outputs.
    """
    branch1_0 = self.branch1_0(x)
    branch1_1 = self.branch1_1(x)
    # The fused map must be returned; without this the module yields None.
    return branch1_0 + branch1_1
边界细化模块(Boundary Refinement, BR)是在全局卷积网络中提出的边界优化模块,其结构类似于残差结构:通过捷径连接(shortcut)对预测结果做进一步修正,以提高目标对象的定位精度,从而起到优化边界的作用。BR模块的残差分支由两个3×3卷积(中间接ReLU激活)构成,其输出与原特征通过恒等映射短接后逐元素相加,再输出边界增强后的预测结果,即 $S' = S + R(S)$,其中 $S$ 为输入的预测图,$R(\cdot)$ 为上述残差分支。
# Boundary Refinement (BR) residual branch on a 4-channel map:
# 3x3 conv -> ReLU -> 3x3 conv. Stride 1 with padding 1 keeps H x W unchanged,
# so the output can be added back onto its input as the identity shortcut.
self.br = nn.Sequential(
    nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
)
def forward(self, x):
    """Decode with GCN + Boundary Refinement (BR) over two skip levels.

    Steps, as wired below:
    1. The backbone returns five feature maps. Only ``the_two_features``,
       ``low_level_features`` and the deepest map (rebound to ``x``) are used.
    2. ``x`` goes through ASPP, is bilinearly upsampled to the size of
       ``low_level_features``, concatenated with it along channels and
       reduced by ``self.conv3``.
    3. A GCN + BR branch on ``low_level_features`` is added to that result.
       The sum is refined again by BR, upsampled to the size of
       ``the_two_features``, and added to a second GCN + BR branch computed
       on ``the_two_features``.
    4. The sum is refined once more and upsampled to the input's (H, W).

    Every BR step is applied residually (``y + self.br(y)``), matching the
    BR definition (conv-ReLU-conv branch plus identity shortcut).

    NOTE(review): one ``self.br`` module is reused for all four refinement
    steps, so they share weights. The GCN paper uses a separate BR per use;
    confirm sharing is intended.

    Args:
        x: input batch; dims 2 and 3 are read as height and width.

    Returns:
        The refined prediction map, bilinearly resized to the input's (H, W).
    """
    H, W = x.size(2), x.size(3)
    the_two_features, low_level_features, _, _, x = self.backbone(x)

    # Deep path: ASPP, then upsample to the low-level resolution.
    x = self.aspp(x)  # 32*32*256 (per the original shape notes)
    x0 = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)),
                       mode='bilinear', align_corners=True)  # 32*32*256 -> 128*128*256
    x1 = torch.cat((x0, low_level_features), dim=1)  # 128*128*280
    x2 = self.conv3(x1)  # 128*128*4

    # GCN + residual BR on the low-level features: 128*128*24 -> 128*128*4
    branch1 = self.branch1_0(low_level_features) + self.branch1_1(low_level_features)
    result1 = self.br(branch1) + branch1

    # Fuse, then refine with BR *including* its identity shortcut. The
    # previous code returned self.br(x3) alone, dropping the residual.
    x3 = x2 + result1  # 128*128*4
    x4 = x3 + self.br(x3)  # 128*128*4
    x5 = F.interpolate(x4, size=(the_two_features.size(2), the_two_features.size(3)),
                       mode='bilinear', align_corners=True)  # 128*128*4 -> 256*256*4

    # GCN + residual BR on the_two_features: 256*256*16 -> 256*256*4
    branch2 = self.branch2_0(the_two_features) + self.branch2_1(the_two_features)
    result2 = self.br(branch2) + branch2

    x6 = x5 + result2  # 256*256*4
    x7 = x6 + self.br(x6)  # residual BR, same fix as above
    return F.interpolate(x7, size=(H, W), mode='bilinear', align_corners=True)
# Four GCN heads on a 24-channel input, each producing 4 channels.
# Head N uses kernel size k = 3, 7, 11, 15 for N = 1..4:
#   branchN_0: (k x 1) conv followed by (1 x k) conv
#   branchN_1: (1 x k) conv followed by (k x 1) conv
# Padding k // 2 along the long axis with stride 1 keeps H x W unchanged.
self.branch1_0 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(3, 1), stride=1, padding=(1, 0)),
    nn.Conv2d(4, 4, kernel_size=(1, 3), stride=1, padding=(0, 1)),
)
self.branch1_1 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(1, 3), stride=1, padding=(0, 1)),
    nn.Conv2d(4, 4, kernel_size=(3, 1), stride=1, padding=(1, 0)),
)
self.branch2_0 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(7, 1), stride=1, padding=(3, 0)),
    nn.Conv2d(4, 4, kernel_size=(1, 7), stride=1, padding=(0, 3)),
)
self.branch2_1 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(1, 7), stride=1, padding=(0, 3)),
    nn.Conv2d(4, 4, kernel_size=(7, 1), stride=1, padding=(3, 0)),
)
self.branch3_0 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(11, 1), stride=1, padding=(5, 0)),
    nn.Conv2d(4, 4, kernel_size=(1, 11), stride=1, padding=(0, 5)),
)
self.branch3_1 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(1, 11), stride=1, padding=(0, 5)),
    nn.Conv2d(4, 4, kernel_size=(11, 1), stride=1, padding=(5, 0)),
)
self.branch4_0 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(15, 1), stride=1, padding=(7, 0)),
    nn.Conv2d(4, 4, kernel_size=(1, 15), stride=1, padding=(0, 7)),
)
self.branch4_1 = nn.Sequential(
    nn.Conv2d(24, 4, kernel_size=(1, 15), stride=1, padding=(0, 7)),
    nn.Conv2d(4, 4, kernel_size=(15, 1), stride=1, padding=(7, 0)),
)
# Boundary Refinement residual branch (3x3 conv -> ReLU -> 3x3 conv) on 4 channels.
self.br = nn.Sequential(
    nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
)
# Fusion head: 296 input channels -> 4 output channels. In the forward pass of
# this listing, 296 = 16 (four GCN+BR heads) + 24 (low-level) + 256 (ASPP).
# The 4 outputs are presumably the per-class prediction map (n classes incl.
# background) - TODO confirm against num_classes.
#
# The final stage previously ended with ReLU(inplace=True) + Dropout(0.1)
# applied to this output. The ReLU clamps every class score to >= 0, which
# kills the gradient for negative scores and makes classes tie at 0. Those
# two layers have no parameters, so removing them leaves the state_dict keys
# (cat_conv.0 ... cat_conv.5) unchanged.
self.cat_conv = nn.Sequential(
    nn.Conv2d(296, 128, 3, stride=1, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Conv2d(128, 4, 3, stride=1, padding=1),
    nn.BatchNorm2d(4),
)
# Four GCN heads (24 -> 4 channels) with kernel sizes 3, 7, 11, 15.
# For head N with kernel size k:
#   branchN_0 = (k x 1) conv then (1 x k) conv
#   branchN_1 = (1 x k) conv then (k x 1) conv
# Padding k // 2 along the long axis with stride 1 keeps H x W unchanged.
# Modules are registered in the same order as explicit assignments would be
# (branch1_0, branch1_1, ..., branch4_1, br).
# NOTE(review): this listing repeats the branch/BR construction that appears
# earlier in the article. If both copies run in the same __init__, this
# second set replaces the first; confirm only one copy is intended.
for head_idx, ksize in enumerate((3, 7, 11, 15), start=1):
    half = ksize // 2
    setattr(self, f'branch{head_idx}_0', nn.Sequential(
        nn.Conv2d(24, 4, (ksize, 1), 1, (half, 0)),
        nn.Conv2d(4, 4, (1, ksize), 1, (0, half))))
    setattr(self, f'branch{head_idx}_1', nn.Sequential(
        nn.Conv2d(24, 4, (1, ksize), 1, (0, half)),
        nn.Conv2d(4, 4, (ksize, 1), 1, (half, 0))))
# Boundary Refinement residual branch: 3x3 conv -> ReLU -> 3x3 conv (4 channels).
self.br = nn.Sequential(
    nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
)
def forward(self, x):
    """Decode with four parallel GCN + BR heads on the shallow features.

    The backbone returns a shallow map and a deep map. The deep map goes
    through ASPP and is bilinearly upsampled to the shallow map's size.
    Each GCN head (``self.branchK_0`` + ``self.branchK_1``, K = 1..4) runs
    on the shallow map and is refined residually by the shared ``self.br``.
    The four refined maps, the shallow map and the upsampled deep map are
    then concatenated along channels and fused by ``self.cat_conv``.

    Args:
        x: input batch; dims 2 and 3 are read as height and width.

    Returns:
        Output of ``self.cat_conv``, bilinearly resized to the input's (H, W).
    """
    out_size = (x.size(2), x.size(3))
    # shallow: e.g. 128*128*24; deep: e.g. 32*32*256 after ASPP (per original notes)
    shallow, deep = self.backbone(x)
    deep = self.aspp(deep)
    deep_up = F.interpolate(deep, size=(shallow.size(2), shallow.size(3)),
                            mode='bilinear', align_corners=True)

    head_pairs = (
        (self.branch1_0, self.branch1_1),
        (self.branch2_0, self.branch2_1),
        (self.branch3_0, self.branch3_1),
        (self.branch4_0, self.branch4_1),
    )
    pieces = []
    for head_a, head_b in head_pairs:
        gcn_map = head_a(shallow) + head_b(shallow)
        # Residual Boundary Refinement: refined = BR(g) + g
        pieces.append(self.br(gcn_map) + gcn_map)

    # Channel order: 4 refined heads (16), shallow (24), upsampled deep (256).
    pieces.extend((shallow, deep_up))
    fused = self.cat_conv(torch.cat(pieces, dim=1))
    return F.interpolate(fused, size=out_size, mode='bilinear', align_corners=True)