Hand-Coding Network Architectures: UNet

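A convolutional neural network built for classification usually outputs a single class label. Many vision tasks, however, need output that also carries spatial information. Semantic segmentation is one example: every pixel of the image must be assigned a class.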
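The difference in output shape can be shown with a small PyTorch sketch. The sizes (batch 1, 2 classes, a 4x4 image) are arbitrary toy values chosen for illustration, not taken from the article. A classifier produces one score vector per image. A segmentation network produces one score vector per pixel, and `nn.CrossEntropyLoss` accepts that layout directly when given integer per-pixel targets.

```python
import torch
import torch.nn as nn

# Assumed toy sizes for illustration only.
N, C, H, W = 1, 2, 4, 4

# Classification: one score vector per image -> one label per image.
cls_logits = torch.randn(N, C)
cls_label = cls_logits.argmax(dim=1)          # shape (N,)

# Segmentation: one score vector per pixel -> one label per pixel.
seg_logits = torch.randn(N, C, H, W)
seg_label = seg_logits.argmax(dim=1)          # shape (N, H, W)

# CrossEntropyLoss takes (N, C, H, W) logits and (N, H, W) integer targets,
# and averages the per-pixel losses into a single scalar.
target = torch.randint(0, C, (N, H, W))
loss = nn.CrossEntropyLoss()(seg_logits, target)

print(cls_label.shape, seg_label.shape, loss.dim())
```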

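This calls for a fully convolutional network (FCN). Compared with FCN, UNet does not simply apply transposed convolutions to the final output of the convolutional layers. It also merges in the intermediate results of the earlier convolutional layers. In this way it uses both the positional information preserved in the early convolutional layers and the classification information from the last layers.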
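One UNet merge step can be sketched in PyTorch as follows. The tensor sizes match the `summary` printout later in this post: a 1024-channel 28x28 map at the bottom of the network, and a 512-channel 64x64 encoder map at the same level. The variable names `bottom` and `skip` are illustrative. A 2x2 transposed convolution with stride 2 doubles the height and width. The matching encoder map is then center-cropped to the same size and stacked on top along the channel axis.

```python
import torch
import torch.nn as nn

# Bottom of the network: 1024 channels at 28x28 (sizes as in the summary below).
bottom = torch.randn(1, 1024, 28, 28)
# Encoder output from the matching level: 512 channels at 64x64.
skip = torch.randn(1, 512, 64, 64)

# 2x2 transposed conv, stride 2: 28x28 -> 56x56, 1024 -> 512 channels.
up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
upsampled = up(bottom)

# Center-crop the encoder map to the upsampled size, then concatenate
# along the channel axis (dim=1), so the channel counts add: 512 + 512.
h, w = upsampled.shape[-2:]
m = (skip.size(-1) - w) // 2
merged = torch.cat([skip[:, :, m:m + h, m:m + w], upsampled], dim=1)

print(upsampled.shape)  # torch.Size([1, 512, 56, 56])
print(merged.shape)     # torch.Size([1, 1024, 56, 56])
```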

UNet Architecture

(Figure 1: UNet network architecture)

UNet has two stages: a downsample stage (the contracting path) and an upsample stage (the expansive path).

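In the figure, each blue box is the feature map of one layer. The number above a box is its channel count, and the number at its lower left is its spatial size. Here the input is 572x572 with 1 channel. Each white box is a cropped copy of the corresponding encoder convolution output. It is concatenated with the blue box next to it, and the result is the input to the following layers.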
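The sizes in the figure can be reproduced with plain arithmetic. The sketch below makes three assumptions that follow the architecture described here: every 3x3 convolution is unpadded (so it removes 2 pixels), each level applies two of them, and pooling and upsampling are 2x2 with stride 2. The function name `unet_sizes` is hypothetical.

```python
def unet_sizes(size, depth=4):
    """Trace the spatial size through an unpadded UNet (illustrative sketch)."""
    enc = []                      # encoder output size per level, reused by the skips
    for _ in range(depth):
        size -= 4                 # two unpadded 3x3 convs: -2 pixels each
        enc.append(size)
        size //= 2                # 2x2 max pool, stride 2 (floors odd sizes)
    size -= 4                     # two 3x3 convs at the bottom of the network
    dec = []
    for _ in range(depth):
        size *= 2                 # 2x2 transposed conv, stride 2
        dec.append(size)
        size -= 4                 # two 3x3 convs after the concatenation
    return enc, dec, size

enc, dec, out = unet_sizes(572)
print(enc)  # [568, 280, 136, 64]
print(dec)  # [56, 104, 200, 392]
print(out)  # 388
```

These numbers match the `summary` printout further down: a 572x572 input yields a 388x388 output map.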

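UNet's convolutions use no padding, so each 3x3 convolution shrinks the feature map by 2 pixels in height and width. As a result, the encoder output at each level is larger than the upsampled decoder map it is paired with. In the upsample stage, each encoder result is therefore center-cropped first and then concatenated with the transposed-convolution output along the channel dimension, which adds the two channel counts together.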
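The amount to crop at each level follows from the size difference. The sketch below hardcodes the (encoder, upsampled) size pairs for a 572x572 input, read from the `summary` printout later in this post and listed deepest level first. The helper name `center_crop_margin` is hypothetical. A center crop only lines up exactly when the difference is even.

```python
# (encoder output size, upsampled decoder size) per level, deepest first,
# for a 572x572 input (values from the summary printout in this post).
pairs = [(64, 56), (136, 104), (280, 200), (568, 392)]

def center_crop_margin(big, small):
    """Pixels to drop from each side so a big x big map becomes small x small."""
    diff = big - small
    assert diff % 2 == 0, "odd difference: a center crop would be off by one"
    return diff // 2

margins = [center_crop_margin(b, s) for b, s in pairs]
print(margins)  # [4, 16, 40, 88]

# After concatenation the channel counts add, e.g. 512 (cropped encoder)
# + 512 (upsampled) = 1024, the input channel count of conv_6 in the code below.
```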

import torch

def concat(tensor1,tensor2):
    # concat 2 NCHW tensors along the channel axis (dim 1),
    # center-cropping the spatially larger one to the size of the smaller one
    if tensor1.size(3) < tensor2.size(3):
        tensor1,tensor2 = tensor2,tensor1
    top = (tensor1.size(2)-tensor2.size(2))//2
    left = (tensor1.size(3)-tensor2.size(3))//2
    tensor1 = tensor1[:, :, top:top+tensor2.size(2), left:left+tensor2.size(3)]
    return torch.cat((tensor1,tensor2),1)

Full code:

import torch
import torch.nn as nn
from torchsummary import summary  # third-party package: pip install torchsummary


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution without padding (shrinks H and W by 2 at stride 1)"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride)

def up_conv2x2(in_planes,out_planes):
    return nn.ConvTranspose2d(in_planes,out_planes,kernel_size=2,stride=2)

def max_pool2x2():
    return nn.MaxPool2d(kernel_size=2,stride=2)


class UNet(nn.Module):
    def __init__(self,class_num=2):
        super(UNet, self).__init__()
        # downsample stage
        self.conv_1 = nn.Sequential(conv3x3(1,64),conv3x3(64,64))
        self.conv_2 = nn.Sequential(conv3x3(64,128),conv3x3(128,128))
        self.conv_3 = nn.Sequential(conv3x3(128,256),conv3x3(256,256))
        self.conv_4 = nn.Sequential(conv3x3(256,512),conv3x3(512,512))
        self.conv_5 = nn.Sequential(conv3x3(512,1024),conv3x3(1024,1024))
        self.maxpool = max_pool2x2()

        # upsample stage
        # up_conv_4 corresponds to conv_4
        self.up_conv_4 = nn.Sequential(up_conv2x2(1024,512))
        # conv the cat(stage_4,up_conv_4) from 1024 to 512 channels
        self.conv_6 = nn.Sequential(conv3x3(1024,512),conv3x3(512,512))
        # up_conv_3 corresponds to conv_3
        self.up_conv_3 = nn.Sequential(up_conv2x2(512,256))
        # conv the cat(stage_3,up_conv_3) from 512 to 256 channels
        self.conv_7 = nn.Sequential(conv3x3(512,256),conv3x3(256,256))
        # up_conv_2 corresponds to conv_2
        self.up_conv_2 = nn.Sequential(up_conv2x2(256,128))
        # conv the cat(stage_2,up_conv_2) from 256 to 128 channels
        self.conv_8 = nn.Sequential(conv3x3(256,128),conv3x3(128,128))
        # up_conv_1 corresponds to conv_1
        self.up_conv_1 = nn.Sequential(up_conv2x2(128,64))
        # conv the cat(stage_1,up_conv_1) from 128 to 64 channels
        self.conv_9 = nn.Sequential(conv3x3(128,64),conv3x3(64,64))
        # output: 1x1 conv mapping 64 channels to class_num per-pixel scores
        self.result = conv1x1(64,class_num)

    def _concat(self,tensor1,tensor2):
        # concat 2 NCHW tensors along the channel axis (dim 1),
        # center-cropping the spatially larger one to the size of the smaller one
        if tensor1.size(3) < tensor2.size(3):
            tensor1,tensor2 = tensor2,tensor1
        top = (tensor1.size(2)-tensor2.size(2))//2
        left = (tensor1.size(3)-tensor2.size(3))//2
        tensor1 = tensor1[:, :, top:top+tensor2.size(2), left:left+tensor2.size(3)]
        return torch.cat((tensor1,tensor2),1)

    def forward(self,x):
        # get the 4 encoder stage outputs (kept for the skip connections)
        stage_1 = self.conv_1(x)
        stage_2 = self.conv_2(self.maxpool(stage_1))
        stage_3 = self.conv_3(self.maxpool(stage_2))
        stage_4 = self.conv_4(self.maxpool(stage_3))

        # get up_conv_4 and concat with stage_4
        up_in_4 = self.conv_5(self.maxpool(stage_4))
        up_stage_4 = self.up_conv_4(up_in_4)
        up_stage_4 = self._concat(stage_4,up_stage_4)
        # get up_conv_3 and concat with stage_3
        up_in_3 = self.conv_6(up_stage_4)
        up_stage_3 = self.up_conv_3(up_in_3)
        up_stage_3 = self._concat(stage_3,up_stage_3)
        # get up_conv_2 and concat with stage_2
        up_in_2 = self.conv_7(up_stage_3)
        up_stage_2 = self.up_conv_2(up_in_2)
        up_stage_2 = self._concat(stage_2,up_stage_2)
        # get up_conv_1 and concat with stage_1
        up_in_1 = self.conv_8(up_stage_2)
        up_stage_1 = self.up_conv_1(up_in_1)
        up_stage_1 = self._concat(stage_1,up_stage_1)

        # conv_9 reduces the concatenated 128 channels to 64
        out = self.conv_9(up_stage_1)
        # 1x1 conv to class_num output channels
        out = self.result(out)
        return out


if __name__ == '__main__':
    ut = UNet(class_num=2)
    # device='cpu' creates the dummy input on the CPU, where the model lives
    summary(ut,(1,572,572),device='cpu')
Running the script prints:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
            Conv2d-2         [-1, 64, 568, 568]          36,928
         MaxPool2d-3         [-1, 64, 284, 284]               0
            Conv2d-4        [-1, 128, 282, 282]          73,856
            Conv2d-5        [-1, 128, 280, 280]         147,584
         MaxPool2d-6        [-1, 128, 140, 140]               0
            Conv2d-7        [-1, 256, 138, 138]         295,168
            Conv2d-8        [-1, 256, 136, 136]         590,080
         MaxPool2d-9          [-1, 256, 68, 68]               0
           Conv2d-10          [-1, 512, 66, 66]       1,180,160
           Conv2d-11          [-1, 512, 64, 64]       2,359,808
        MaxPool2d-12          [-1, 512, 32, 32]               0
           Conv2d-13         [-1, 1024, 30, 30]       4,719,616
           Conv2d-14         [-1, 1024, 28, 28]       9,438,208
  ConvTranspose2d-15          [-1, 512, 56, 56]       2,097,664
           Conv2d-16          [-1, 512, 54, 54]       4,719,104
           Conv2d-17          [-1, 512, 52, 52]       2,359,808
  ConvTranspose2d-18        [-1, 256, 104, 104]         524,544
           Conv2d-19        [-1, 256, 102, 102]       1,179,904
           Conv2d-20        [-1, 256, 100, 100]         590,080
  ConvTranspose2d-21        [-1, 128, 200, 200]         131,200
           Conv2d-22        [-1, 128, 198, 198]         295,040
           Conv2d-23        [-1, 128, 196, 196]         147,584
  ConvTranspose2d-24         [-1, 64, 392, 392]          32,832
           Conv2d-25         [-1, 64, 390, 390]          73,792
           Conv2d-26         [-1, 64, 388, 388]          36,928
           Conv2d-27          [-1, 2, 388, 388]             130
================================================================
Total params: 31,030,658
Trainable params: 31,030,658
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 1096.59
Params size (MB): 118.37
Estimated Total Size (MB): 1216.21
----------------------------------------------------------------
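The listing above stacks the 3x3 `Conv2d` layers directly, with no activation between them, and the printout accordingly shows no ReLU rows. The UNet paper (Ronneberger et al., 2015) follows each 3x3 convolution with a ReLU. Below is a minimal sketch of a helper that adds this. The name `double_conv` and its `padding` argument are hypothetical and not part of the code above. With `padding=0` it keeps the unpadded convolutions used in this post. With `padding=1` it implements a common variant that preserves height and width, so the skip connections need no cropping, provided the input size is divisible by 16 for four pooling steps. Replacing each `nn.Sequential(conv3x3(a,b),conv3x3(b,b))` with `double_conv(a,b)` leaves the parameter count unchanged, because ReLU has no parameters.

```python
import torch
import torch.nn as nn

def double_conv(in_ch, out_ch, padding=0):
    """Two 3x3 convs, each followed by ReLU (hypothetical helper).

    padding=0: unpadded convs as in this post (H and W shrink by 4 in total).
    padding=1: size-preserving variant, so no crop is needed before concat.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
    )

x = torch.randn(1, 1, 64, 64)
print(double_conv(1, 64)(x).shape)             # torch.Size([1, 64, 60, 60])
print(double_conv(1, 64, padding=1)(x).shape)  # torch.Size([1, 64, 64, 64])
```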
