Pytorch搭建FCN网络

Pytorch搭建FCN网络

  • 前言
    • 原理
    • 代码实现

前言

FCN 全卷积网络,用卷积层替代CNN的全连接层,最后通过转置卷积层得到一个和输入尺寸一致的预测结果:
Pytorch搭建FCN网络_第1张图片

原理

为了得到更好的分割结果,论文中提到了几种网络结构FCN-32s、FCN-16s、FCN-8s,如图所示:
Pytorch搭建FCN网络_第2张图片
特征提取骨干网络是一个VGG风格的网络,由多个vgg_block组成,每个vgg_block重复应用数个卷积层+ReLU层,再加上一个池化层。池化层会使宽高减半,也就是说在pool1、pool2、pool3、pool4、pool5之后,特征图的分辨率会分别变为原图像的 1 2 \frac {1} {2} 21 1 4 \frac {1} {4} 41 1 8 \frac {1} {8} 81 1 16 \frac {1} {16} 161 1 32 \frac {1} {32} 321。然后,FCN将VGG原来的两个全连接层替换成卷积层conv6-7,此时得到的结果仍然为原图的 1 32 \frac {1} {32} 321
对于FCN-32s,将conv7通过转置卷积上采样32倍得到;
对于FCN-16s,将pool4与conv7上采样2倍得到的特征图进行相加融合,此时得到的特征图尺寸为原图的 1 16 \frac {1} {16} 161,再上采样16倍得到。
对于FCN-8s,将pool3、pool4上采样2倍得到的特征图、conv7上采样4倍得到的特征图进行相加融合,此时得到的特征图尺寸为原图的 1 8 \frac {1} {8} 81,再上采样8倍得到。

为什么要采取FCN-16s、FCN-8s这两种融合方式呢?这是因为CNN的浅层卷积学习到的是局部特征(边缘、纹理),深层卷积学习语义特征,提高分类性能。而语义分割对细节特征精度要求较高,对于FCN-32s,直接上采样无法恢复细节信息,因此自然想到将浅层网络学习到的特征与深层特征相融合,有了FCN-16s、FCN-8s。实验证明,确实是FCN-8s细节特征最丰富,分割效果最好。
Pytorch搭建FCN网络_第3张图片

代码实现

下面用Pytorch搭建FCN-8s的网络结构。首先是骨干网络,采用VGG16,注意将VGG原来的两个全连接层替换成卷积层conv6和conv7,并保存pool3、pool4、conv7的结果:

from torch import nn
from torchvision.models import vgg16

def vgg_block(num_convs, in_channels, out_channels):
    blk = []
    for i in range(num_convs):
        if i == 0:
            blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        else:
            blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        blk.append(nn.ReLU(inplace=True))
    blk.append(nn.MaxPool2d(kernel_size=2, stride=2))  # 宽高减半
    return blk

class VGG16(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG16, self).__init__()
        features = []
        features.extend(vgg_block(2, 3, 64))
        features.extend(vgg_block(2, 64, 128))
        features.extend(vgg_block(3, 128, 256))
        self.index_pool3 = len(features)
        features.extend(vgg_block(3, 256, 512))
        self.index_pool4 = len(features)
        features.extend(vgg_block(3, 512, 512))
        self.features = nn.Sequential(*features)

        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)

        # load pretrained params from torchvision.models.vgg16(pretrained=True)
        if pretrained:
            pretrained_model = vgg16(pretrained=pretrained)
            pretrained_params = pretrained_model.state_dict()
            keys = list(pretrained_params.keys())
            new_dict = {}
            for index, key in enumerate(self.features.state_dict().keys()):
                new_dict[key] = pretrained_params[keys[index]]
            self.features.load_state_dict(new_dict)

    def forward(self, x):
        pool3 = self.features[:self.index_pool3](x)      # 1/8
        pool4 = self.features[self.index_pool3:self.index_pool4](pool3)  # 1/16
        pool5 = self.features[self.index_pool4:](pool4)  # 1/32

        conv6 = self.relu(self.conv6(pool5))  # 1/32
        conv7 = self.relu(self.conv7(conv6))  # 1/32

        return pool3, pool4, conv7

然后是FCN-8s,将pool3、pool4上采样2倍得到的特征图、conv7上采样4倍得到的特征图进行相加融合,再上采样8倍。

class FCN(nn.Module):
    def __init__(self, num_classes, backbone='vgg'):
        super(FCN, self).__init__()
        if backbone == 'vgg':
            self.features = VGG16()

        self.scores1 = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.scores2 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.scores3 = nn.Conv2d(256, num_classes, kernel_size=1)

        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=8, stride=8)
        self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=4)
        self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, stride=2)

    def forward(self, x):
        pool3, pool4, conv7 = self.features(x)

        conv7 = self.relu(self.scores1(conv7))  # 1×1卷积将通道数映射为类别数

        pool4 = self.relu(self.scores2(pool4))  # 1×1卷积将通道数映射为类别数

        pool3 = self.relu(self.scores3(pool3))  # 1×1卷积将通道数映射为类别数

        s = pool3 + self.upsample_2x(pool4) + self.upsample_4x(conv7)  # 相加融合
        out_8s = self.upsample_8x(s)  # 8倍上采样

        return out_8s

打印一下网络结构:

net = FCN(num_classes=21)
from torchsummary import summary
summary(net.cuda(), (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Conv2d-32           [-1, 4096, 7, 7]       2,101,248
             ReLU-33           [-1, 4096, 7, 7]               0
           Conv2d-34           [-1, 4096, 7, 7]      16,781,312
             ReLU-35           [-1, 4096, 7, 7]               0
            VGG16-36  [[-1, 256, 28, 28], [-1, 512, 14, 14], [-1, 4096, 7, 7]]               0
           Conv2d-37             [-1, 21, 7, 7]          86,037
             ReLU-38             [-1, 21, 7, 7]               0
           Conv2d-39           [-1, 21, 14, 14]          10,773
             ReLU-40           [-1, 21, 14, 14]               0
           Conv2d-41           [-1, 21, 28, 28]           5,397
             ReLU-42           [-1, 21, 28, 28]               0
  ConvTranspose2d-43           [-1, 21, 28, 28]           1,785
  ConvTranspose2d-44           [-1, 21, 28, 28]           7,077
  ConvTranspose2d-45         [-1, 21, 224, 224]          28,245
================================================================
Total params: 33,736,562
Trainable params: 33,736,562
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 233.14
Params size (MB): 128.69
Estimated Total Size (MB): 362.41
----------------------------------------------------------------

你可能感兴趣的:(语义分割)