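FCN (Fully Convolutional Network) replaces the fully connected layers of a CNN with convolutional layers, then uses transposed convolution layers to produce a prediction with the same spatial size as the input.
To get better segmentation results, the paper describes several network variants, FCN-32s, FCN-16s, and FCN-8s, as shown in the figure: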
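The feature-extraction backbone is a VGG-style network made up of several vgg_block units. Each vgg_block applies a few convolution + ReLU layers followed by one pooling layer. Every pooling layer halves the width and height, so after pool1, pool2, pool3, pool4, and pool5 the feature-map resolution is $\frac{1}{2}$, $\frac{1}{4}$, $\frac{1}{8}$, $\frac{1}{16}$, and $\frac{1}{32}$ of the original image, respectively. FCN then replaces VGG's original two fully connected layers with the convolutional layers conv6 and conv7, whose output is still $\frac{1}{32}$ of the original image.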
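As a quick sanity check of the halving arithmetic, the sketch below computes the feature-map side length after each pooling stage. The helper name `pooled_sizes` is hypothetical, and it assumes every pooling layer is a 2×2, stride-2 max pool without padding, which floors odd sizes.

```python
def pooled_sizes(input_size, num_pools=5):
    """Return the side length after each 2x2, stride-2 pooling stage."""
    sizes = []
    size = input_size
    for _ in range(num_pools):
        size = size // 2  # a 2x2, stride-2 pool without padding floors odd sizes
        sizes.append(size)
    return sizes


sizes_224 = pooled_sizes(224)
print(sizes_224)  # [112, 56, 28, 14, 7], i.e. 1/2, 1/4, 1/8, 1/16, 1/32 of 224
```

For a 224×224 input, pool3, pool4, and pool5 therefore produce 28×28, 14×14, and 7×7 feature maps.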
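For FCN-32s, conv7 is upsampled 32× with a transposed convolution to produce the prediction.
For FCN-16s, conv7 is upsampled 2× and added element-wise to pool4. The fused feature map is $\frac{1}{16}$ of the original image and is then upsampled 16× to produce the prediction.
For FCN-8s, three feature maps are added element-wise: pool3, pool4 upsampled 2×, and conv7 upsampled 4×. The fused feature map is $\frac{1}{8}$ of the original image and is then upsampled 8× to produce the prediction.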
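The three variants can be traced with plain size arithmetic. The sketch below uses the output-size formula documented for PyTorch's `ConvTranspose2d` and assumes `kernel_size == stride` with no padding, the configuration used by the FCN-8s code in this post. The helper names `conv_transpose_out` and `up` are hypothetical.

```python
def conv_transpose_out(size, kernel_size, stride, padding=0,
                       output_padding=0, dilation=1):
    """Output side length of a transposed convolution (ConvTranspose2d formula)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def up(size, factor):
    """Upsample by `factor` using kernel_size == stride == factor, padding 0."""
    return conv_transpose_out(size, kernel_size=factor, stride=factor)


H = 224
pool3, pool4, conv7 = H // 8, H // 16, H // 32  # 28, 14, 7

# FCN-32s: conv7 -> 32x
fcn32s = up(conv7, 32)

# FCN-16s: conv7 -> 2x, fuse with pool4, then 16x
fused16 = up(conv7, 2)
assert fused16 == pool4  # the two branches must match before they are added
fcn16s = up(fused16, 16)

# FCN-8s: pool3 + 2x(pool4) + 4x(conv7), then 8x
assert up(pool4, 2) == up(conv7, 4) == pool3
fcn8s = up(pool3, 8)

print(fcn32s, fcn16s, fcn8s)  # 224 224 224
```

With `kernel_size == stride` the formula reduces to `size * factor`, which is why each variant lands back on the input resolution.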
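Why use the FCN-16s and FCN-8s fusion schemes? Shallow convolutional layers in a CNN learn local features (edges, textures), while deep layers learn semantic features that improve classification. Semantic segmentation, however, demands precise detail. FCN-32s upsamples directly and cannot recover that detail, so it is natural to fuse the features learned by shallow layers with the deep features, which gives FCN-16s and FCN-8s. Experiments confirm that FCN-8s keeps the richest detail and gives the best segmentation results.
Next, we build the FCN-8s network in PyTorch. First comes the backbone, VGG16. Note that VGG's original two fully connected layers are replaced with the convolutional layers conv6 and conv7, and the outputs of pool3, pool4, and conv7 are saved: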
from torch import nn
from torchvision.models import vgg16


def vgg_block(num_convs, in_channels, out_channels):
    """Build one VGG block: num_convs x (3x3 conv + ReLU), then a 2x2 max pool."""
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        else:
            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        blk.append(nn.ReLU(inplace=True))
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2))  # halves width and height
    return blk


class VGG16(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG16, self).__init__()
        features = []
        features.extend(vgg_block(2, 3, 64))
        features.extend(vgg_block(2, 64, 128))
        features.extend(vgg_block(3, 128, 256))
        self.index_pool3 = len(features)  # features[:index_pool3] ends at pool3
        features.extend(vgg_block(3, 256, 512))
        self.index_pool4 = len(features)  # features[:index_pool4] ends at pool4
        features.extend(vgg_block(3, 512, 512))
        self.features = nn.Sequential(*features)
        # conv6/conv7 replace the fully connected layers; the pretrained weights
        # loaded below do not cover them, so they keep their default initialization
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        # load pretrained params from torchvision.models.vgg16(pretrained=True)
        # (newer torchvision releases deprecate `pretrained` in favor of `weights`)
        if pretrained:
            pretrained_model = vgg16(pretrained=pretrained)
            pretrained_params = pretrained_model.state_dict()
            keys = list(pretrained_params.keys())
            new_dict = {}
            # self.features has the same layer order as torchvision's `features`,
            # so the leading keys line up position by position
            for index, key in enumerate(self.features.state_dict().keys()):
                new_dict[key] = pretrained_params[keys[index]]
            self.features.load_state_dict(new_dict)

    def forward(self, x):
        pool3 = self.features[:self.index_pool3](x)  # 1/8
        pool4 = self.features[self.index_pool3:self.index_pool4](pool3)  # 1/16
        pool5 = self.features[self.index_pool4:](pool4)  # 1/32
        conv6 = self.relu(self.conv6(pool5))  # 1/32
        conv7 = self.relu(self.conv7(conv6))  # 1/32
        return pool3, pool4, conv7
Next comes FCN-8s: pool3, pool4 upsampled 2×, and conv7 upsampled 4× are added element-wise, and the sum is then upsampled 8×.
class FCN(nn.Module):
    def __init__(self, num_classes, backbone='vgg'):
        super(FCN, self).__init__()
        if backbone == 'vgg':
            self.features = VGG16()
        else:
            # without this check, an unknown backbone would only fail later in forward()
            raise ValueError("unsupported backbone: {}".format(backbone))
        # 1x1 convolutions that map each feature map to num_classes channels
        self.scores1 = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.scores2 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.scores3 = nn.Conv2d(256, num_classes, kernel_size=1)
        # transposed convolutions with kernel_size == stride upsample by exactly that factor
        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=8, stride=8)
        self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=4)
        self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, stride=2)

    def forward(self, x):
        pool3, pool4, conv7 = self.features(x)
        conv7 = self.relu(self.scores1(conv7))  # 1x1 conv maps channels to the number of classes
        pool4 = self.relu(self.scores2(pool4))  # 1x1 conv maps channels to the number of classes
        pool3 = self.relu(self.scores3(pool3))  # 1x1 conv maps channels to the number of classes
        s = pool3 + self.upsample_2x(pool4) + self.upsample_4x(conv7)  # element-wise fusion
        out_8s = self.upsample_8x(s)  # 8x upsampling
        return out_8s
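One consequence of the element-wise addition in `forward` is that the three branches must have identical spatial sizes. The sketch below traces square input sizes through the pooling and upsampling arithmetic to see which ones line up. It models only the sizes, not the network itself, and the helper names `fcn8s_sizes` and `is_compatible` are hypothetical.

```python
def fcn8s_sizes(input_size):
    """Trace side lengths through the FCN-8s head for a square input."""
    size = input_size
    stages = []
    for _ in range(5):
        size //= 2  # each 2x2, stride-2 max pool floors odd sizes
        stages.append(size)
    pool3, pool4, pool5 = stages[2], stages[3], stages[4]
    return {
        "pool3": pool3,
        "up2x(pool4)": 2 * pool4,   # kernel_size == stride == 2
        "up4x(conv7)": 4 * pool5,   # conv6/conv7 are 1x1, so conv7 has pool5's size
        "output": 8 * pool3,        # kernel_size == stride == 8
    }


def is_compatible(input_size):
    """True if the three branches can be added and the output matches the input."""
    sizes = fcn8s_sizes(input_size)
    branches_match = sizes["pool3"] == sizes["up2x(pool4)"] == sizes["up4x(conv7)"]
    return branches_match and sizes["output"] == input_size


for n in (224, 250, 256):
    print(n, fcn8s_sizes(n), is_compatible(n))
```

Under this arithmetic, 224 and 256 pass while 250 fails (pool3 is 31, but the upsampled pool4 and conv7 are 30 and 28). In general the check passes exactly when the input side length is a multiple of 32, so other inputs would need to be padded or resized first.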
Print the network structure:
net = FCN(num_classes=21)
from torchsummary import summary
summary(net.cuda(), (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Conv2d-32           [-1, 4096, 7, 7]       2,101,248
             ReLU-33           [-1, 4096, 7, 7]               0
           Conv2d-34           [-1, 4096, 7, 7]      16,781,312
             ReLU-35           [-1, 4096, 7, 7]               0
            VGG16-36  [[-1, 256, 28, 28], [-1, 512, 14, 14], [-1, 4096, 7, 7]]               0
           Conv2d-37             [-1, 21, 7, 7]          86,037
             ReLU-38             [-1, 21, 7, 7]               0
           Conv2d-39           [-1, 21, 14, 14]          10,773
             ReLU-40           [-1, 21, 14, 14]               0
           Conv2d-41           [-1, 21, 28, 28]           5,397
             ReLU-42           [-1, 21, 28, 28]               0
  ConvTranspose2d-43           [-1, 21, 28, 28]           1,785
  ConvTranspose2d-44           [-1, 21, 28, 28]           7,077
  ConvTranspose2d-45         [-1, 21, 224, 224]          28,245
================================================================
Total params: 33,736,562
Trainable params: 33,736,562
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 233.14
Params size (MB): 128.69
Estimated Total Size (MB): 362.41
----------------------------------------------------------------
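The parameter counts in the summary can be reproduced by hand. The sketch below assumes the usual count for a convolution with bias, `in_channels * out_channels * k * k + out_channels`, which also applies to the transposed convolutions used here. The helper name `conv_params` is hypothetical.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a (transposed) convolution with a square kernel."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


num_classes = 21

conv6 = conv_params(512, 4096, 1)                       # Conv2d-32
conv7 = conv_params(4096, 4096, 1)                      # Conv2d-34
scores1 = conv_params(4096, num_classes, 1)             # Conv2d-37
scores2 = conv_params(512, num_classes, 1)              # Conv2d-39
scores3 = conv_params(256, num_classes, 1)              # Conv2d-41
upsample_2x = conv_params(num_classes, num_classes, 2)  # ConvTranspose2d-43
upsample_4x = conv_params(num_classes, num_classes, 4)  # ConvTranspose2d-44
upsample_8x = conv_params(num_classes, num_classes, 8)  # ConvTranspose2d-45

print(conv6, conv7)                           # 2101248 16781312
print(scores1, scores2, scores3)              # 86037 10773 5397
print(upsample_2x, upsample_4x, upsample_8x)  # 1785 7077 28245
```

These values match the corresponding rows of the summary. Nearly all of the parameters outside the pretrained VGG feature layers sit in conv6 and conv7, while the score and upsampling layers together contribute only about 0.14 million.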