U-Net: Convolutional Networks for Biomedical Image Segmentation

Paper: U-Net: Convolutional Networks for Biomedical Image Segmentation

Key contributions

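  1. Proposes a U-shaped encoder-decoder architecture whose skip connections better fuse the localization (position) information of shallow layers with the semantic information of deep layers. U-Net follows FCN in being fully convolutional; one important change relative to FCN is that the upsampling path also has a large number of feature channels, which allows the network to propagate context information to higher-resolution layers.
  2. Training data for medical image segmentation is very scarce, so the authors apply extensive data augmentation based on elastic deformations.
  3. Proposes a weighted loss, which puts extra weight on the thin borders separating touching objects (cells).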
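Item 3 refers to the paper's pixel-wise weight map, w(x) = w_c(x) + w_0 · exp(−(d_1(x) + d_2(x))² / (2σ²)), where w_c balances class frequencies, d_1 and d_2 are the distances to the borders of the nearest and the second-nearest cell, and the paper uses w_0 = 10 and σ ≈ 5 pixels. The minimal sketch below only evaluates that formula for individual pixels and plugs it into a weighted cross-entropy term; computing w_c, d_1 and d_2 from a real label image (for example with distance transforms) is omitted, and the function names are hypothetical.

```python
import math


def unet_weight(w_c, d1, d2, w0=10.0, sigma=5.0):
    """Per-pixel weight w(x) = w_c(x) + w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)).

    w_c: class-balancing weight of this pixel
    d1:  distance to the border of the nearest cell
    d2:  distance to the border of the second-nearest cell
    """
    return w_c + w0 * math.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))


def weighted_pixel_ce(prob_true_class, weight):
    """Weighted cross-entropy of one pixel: -w(x) * log p_{l(x)}(x)."""
    return -weight * math.log(prob_true_class)


# A background pixel squeezed between two touching cells gets a large weight,
# while a pixel far from any border keeps (almost exactly) its class weight.
w_border = unet_weight(w_c=1.0, d1=1.0, d2=1.0)
w_far = unet_weight(w_c=1.0, d1=30.0, d2=40.0)
print(round(w_border, 3), round(w_far, 3))
print(weighted_pixel_ce(0.5, w_border))
```

Because the extra term decays with (d_1 + d_2)², only the narrow gaps between neighbouring cells are emphasised, which pushes the network to learn the separation borders.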

[Figure 1: U-Net: Convolutional Networks for Biomedical Image Segmentation]


Implementation details worth noting

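  1. The original paper uses no padding ("valid" convolutions), so every 3x3 convolution shrinks the feature map by 2 pixels per spatial dimension and the resolution decreases progressively. The MMSegmentation implementation described below uses padding, so a convolution with stride=1 keeps the feature-map resolution unchanged.
  2. In FCN, the skip connections fuse shallow and deep features by element-wise addition, whereas U-Net fuses them by concatenation along the channel dimension.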
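Both points can be checked with simple shape arithmetic. The sketch below traces one spatial dimension through a 4-level U-Net with either valid (padding=0, as in the paper) or padded (padding=1, as in the printed MMSegmentation model) 3x3 convolutions, and compares the channel count after add-style versus concat-style fusion. It is a hypothetical helper that only tracks sizes, and it ignores the cropping of skip features that the valid-convolution version needs.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def unet_spatial_size(size, padding, depth=4):
    """Trace one spatial dimension through a U-Net with `depth` pooling steps."""
    def conv_block(s):
        # two 3x3, stride-1 convolutions per block
        return conv_out(conv_out(s, padding=padding), padding=padding)

    s = conv_block(size)
    for _ in range(depth):
        s = conv_block(s // 2)  # 2x2 max pool with stride 2, then a ConvBlock
    for _ in range(depth):
        s = conv_block(s * 2)  # x2 upsampling, then a ConvBlock
    return s


def fused_channels(c_skip, c_up, mode):
    """Channel count after fusing a skip feature with an upsampled feature."""
    if mode == "add":  # FCN style: element-wise sum, channel counts must match
        if c_skip != c_up:
            raise ValueError("add fusion needs equal channel counts")
        return c_skip
    if mode == "concat":  # U-Net style: stack along the channel axis
        return c_skip + c_up
    raise ValueError(f"unknown mode: {mode}")


print(unet_spatial_size(572, padding=0))  # valid convs: the paper's 572 input tile
print(unet_spatial_size(480, padding=1))  # padded convs: resolution is preserved
print(fused_channels(64, 64, "add"), fused_channels(64, 64, "concat"))
```

With valid convolutions the 572x572 input tile of the paper yields a 388x388 output, while the padded variant maps 480 back to 480, which is why the MMSegmentation shapes below stay aligned without cropping.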

Implementation walkthrough

Take the U-Net implementation in MMSegmentation as an example, and assume batch_size=4 and an input of shape (4, 3, 480, 480) in (N, C, H, W) order.

Backbone

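  • The encoder has 5 stages, each containing one ConvBlock made of two Conv-BN-ReLU layers. Every stage except the first has a 2x2, stride-2 max pool (2x2-s2 maxpool) before its ConvBlock. The first conv of each stage doubles the number of channels (in the first stage it maps the 3 input channels to 64). The encoder stage outputs therefore have shapes (4, 64, 480, 480), (4, 128, 240, 240), (4, 256, 120, 120), (4, 512, 60, 60) and (4, 1024, 30, 30).
  • The decoder has 4 stages, corresponding to the last 4 (downsampling) stages of the encoder. Each stage consists of three steps: upsample, concatenate and conv. Upsample is a bilinear interpolation with scale_factor=2 followed by one Conv-BN-ReLU whose conv is a 1x1-s1 convolution that halves the number of channels. Concatenate joins the upsampled output with the encoder output of the same resolution along the channel dimension. The conv step is a ConvBlock which, as in the encoder, consists of two Conv-BN-ReLU layers: upsampling halved the channels and concatenation with the encoder output restored them, so the first conv of this ConvBlock halves them again. The backbone outputs therefore have shapes (4, 1024, 30, 30), (4, 512, 60, 60), (4, 256, 120, 120), (4, 128, 240, 240) and (4, 64, 480, 480). Note that the decoder has only 4 stages, so only the last 4 of these are decoder outputs; the first one is simply the output of the last encoder stage.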
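The shape bookkeeping of both bullets can be reproduced in plain Python. This sketch only tracks (N, C, H, W) tuples, builds no tensors, and uses hypothetical function names; base_channels=64 and num_stages=5 are read off the printed model.

```python
def encoder_shapes(batch=4, size=480, base_channels=64, num_stages=5):
    """Output shape (N, C, H, W) of every encoder stage."""
    shapes = []
    s = size
    for i in range(num_stages):
        if i > 0:
            s //= 2  # 2x2-s2 max pool before the ConvBlock
        # the first 3x3 conv sets the stage width, the second one keeps it
        channels = base_channels * 2 ** i
        shapes.append((batch, channels, s, s))
    return shapes


def decoder_shapes(enc_shapes):
    """Feature maps returned by the backbone: the last encoder output
    followed by the output of each of the 4 decoder stages."""
    n, c, h, w = enc_shapes[-1]
    outs = [(n, c, h, w)]
    # walk the skip connections from the deepest to the shallowest one
    for skip in reversed(enc_shapes[:-1]):
        # upsample: x2 bilinear interpolation + 1x1 conv halving the channels
        c, h, w = c // 2, h * 2, w * 2
        assert (h, w) == skip[2:], "skip feature must have the same resolution"
        # concatenate along the channel axis: channels are restored
        c = c + skip[1]
        # ConvBlock: first 3x3 conv halves the channels, second keeps them
        c = c // 2
        outs.append((n, c, h, w))
    return outs


enc = encoder_shapes()
dec = decoder_shapes(enc)
for shape in dec:
    print(shape)
```

The assert in the loop makes the padding assumption explicit: with padded convolutions every upsampled map lines up exactly with its skip feature, so no cropping is needed before concatenation.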

FCN Head

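  • The output of the last decoder stage, (4, 64, 480, 480), is the input to the head. It first passes through a 3x3-s1 Conv-BN-ReLU that keeps the number of channels unchanged, then through a dropout with p=0.1, and finally through a 1x1 conv that produces the model's final output, whose number of channels equals the number of classes (including background).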
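The final 1x1 conv (conv_seg, printed below as Conv2d(64, 2, kernel_size=(1, 1))) is a single linear layer applied independently at every pixel. The sketch below shows that on nested Python lists for one sample, without any framework; the argmax helper is a hypothetical addition showing how a label map is typically read off such logits, not something the head itself does.

```python
def conv1x1(feature, weight, bias):
    """Apply a 1x1 convolution to one sample.

    feature: nested list [C_in][H][W]
    weight:  nested list [C_out][C_in]
    bias:    list [C_out]
    Returns [C_out][H][W]: every pixel's C_in-vector is mapped to C_out
    class scores by the same weights.
    """
    c_in = len(feature)
    h, w = len(feature[0]), len(feature[0][0])
    out = []
    for k in range(len(weight)):
        plane = [[bias[k] + sum(weight[k][c] * feature[c][y][x] for c in range(c_in))
                  for x in range(w)]
                 for y in range(h)]
        out.append(plane)
    return out


def predict(logits):
    """Per-pixel argmax over the class dimension, giving a label map [H][W]."""
    num_classes = len(logits)
    h, w = len(logits[0]), len(logits[0][0])
    return [[max(range(num_classes), key=lambda k: logits[k][y][x]) for x in range(w)]
            for y in range(h)]


# 2 input channels, a 1x2 image, 2 classes (like conv_seg, but tiny)
features = [[[0.9, 0.1]], [[0.2, 0.8]]]
scores = conv1x1(features, weight=[[1.0, -1.0], [-1.0, 1.0]], bias=[0.0, 0.0])
print(scores)
print(predict(scores))
```

Since no spatial neighbourhood is involved, the output resolution is exactly the input resolution; only the channel dimension changes from 64 to the number of classes.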

Loss

  • The loss is the cross-entropy loss.
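For segmentation the cross-entropy is computed per pixel and then averaged. The printed config uses ignore_index=255 and avg_non_ignore=False; the sketch below reflects our reading of that flag (ignored pixels contribute zero loss and, when it is False, still count in the averaging denominator), so verify against the MMSegmentation source before relying on it. The function is a hypothetical pure-Python illustration for one sample.

```python
import math


def pixel_cross_entropy(logits, labels, ignore_index=255, avg_non_ignore=False):
    """Mean per-pixel cross-entropy for one sample.

    logits: nested list [C][H][W] of raw class scores
    labels: nested list [H][W] of class ids; ignore_index marks pixels to skip
    Ignored pixels add 0 to the loss; avg_non_ignore decides whether they are
    excluded from the denominator (True) or still counted (False).
    """
    num_classes = len(logits)
    total, n_valid, n_all = 0.0, 0, 0
    for y, row in enumerate(labels):
        for x, label in enumerate(row):
            n_all += 1
            if label == ignore_index:
                continue
            scores = [logits[k][y][x] for k in range(num_classes)]
            m = max(scores)  # subtract the max for numerical stability
            log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_sum_exp - scores[label]  # -log softmax of the true class
            n_valid += 1
    denom = n_valid if avg_non_ignore else n_all
    return total / max(denom, 1)


# 2 classes, a 1x2 image whose second pixel is ignored
uniform = [[[0.0, 0.0]], [[0.0, 0.0]]]
print(pixel_cross_entropy(uniform, [[0, 255]], avg_non_ignore=True))
print(pixel_cross_entropy(uniform, [[0, 255]], avg_non_ignore=False))
```

With uniform logits each valid pixel costs log 2; the two calls differ only in how the single ignored pixel affects the average.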

Auxiliary Head

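  • The output of the second-to-last decoder stage, (4, 128, 240, 240), is the input to the auxiliary head. It passes through a 3x3-s1 Conv-BN-ReLU that halves the channels to 64, then a dropout with p=0.1, and finally a 1x1 conv that produces this branch's output, whose number of channels equals the number of classes (including background).
  • The auxiliary branch also uses cross-entropy loss. Note that the output of this branch has half the resolution of the original ground truth, so it is first upsampled with bilinear interpolation before the loss is computed.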
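The upsampling step can be sketched for a single 2D map (one class channel of the auxiliary logits, e.g. 240x240 to 480x480). The printed heads use align_corners=False, and the sketch follows the half-pixel-center convention we understand PyTorch's bilinear interpolation to use in that mode; it is a hypothetical pure-Python illustration, not the library's code.

```python
import math


def bilinear_resize(img, out_h, out_w):
    """Bilinear resize of a 2D map [H][W] with the align_corners=False convention.

    Each output pixel center is mapped back to input coordinates with
    src = (dst + 0.5) * in / out - 0.5 (clamped at 0), and the 4 nearest
    input pixels are blended linearly.
    """
    in_h, in_w = len(img), len(img[0])

    def source_index(dst, in_size, out_size):
        src = max((dst + 0.5) * in_size / out_size - 0.5, 0.0)
        i0 = int(math.floor(src))
        i1 = min(i0 + 1, in_size - 1)  # replicate the last row/column at the edge
        return i0, i1, src - i0

    out = []
    for y in range(out_h):
        y0, y1, ly = source_index(y, in_h, out_h)
        row = []
        for x in range(out_w):
            x0, x1, lx = source_index(x, in_w, out_w)
            top = (1 - lx) * img[y0][x0] + lx * img[y0][x1]
            bottom = (1 - lx) * img[y1][x0] + lx * img[y1][x1]
            row.append((1 - ly) * top + ly * bottom)
        out.append(row)
    return out


# a 2x2 logit map upsampled by a factor of 2, like the auxiliary branch
for row in bilinear_resize([[0.0, 1.0], [2.0, 3.0]], 4, 4):
    print(row)
```

Interpolating the logits (rather than downsampling the labels) keeps the loss defined on the full-resolution ground truth, so the auxiliary and main branches are supervised by the same target.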

Full model structure

EncoderDecoder(
  (backbone): UNet(
    (encoder): ModuleList(
      (0): Sequential(
        (0): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (4): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
    )
    (decoder): ModuleList(
      (0): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (1): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (2): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
      (3): UpConvBlock(
        (conv_block): BasicConvBlock(
          (convs): Sequential(
            (0): ConvModule(
              (conv): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
            (1): ConvModule(
              (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
        (upsample): InterpConv(
          (interp_upsample): Sequential(
            (0): Upsample()
            (1): ConvModule(
              (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): _BatchNormXd(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (activate): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  init_cfg=[{'type': 'Kaiming', 'layer': 'Conv2d'}, {'type': 'Constant', 'val': 1, 'layer': ['_BatchNorm', 'GroupNorm']}]
  (decode_head): FCNHead(
    input_transform=None, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
    (conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    (dropout): Dropout2d(p=0.1, inplace=False)
    (convs): Sequential(
      (0): ConvModule(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
  (auxiliary_head): FCNHead(
    input_transform=None, ignore_index=255, align_corners=False
    (loss_decode): CrossEntropyLoss(avg_non_ignore=False)
    (conv_seg): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    (dropout): Dropout2d(p=0.1, inplace=False)
    (convs): Sequential(
      (0): ConvModule(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): _BatchNormXd(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activate): ReLU(inplace=True)
      )
    )
  )
  init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}
)
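As a sanity check on the printed structure, the trainable parameters can be tallied from the layer shapes alone. This sketch assumes what the print shows: the convs inside ConvModule have bias=False, BatchNorm is affine (a weight and a bias per channel; running statistics are buffers, not parameters), and conv_seg is a 1x1 Conv2d with bias. The function names are hypothetical, and nothing is built or loaded.

```python
def conv_bn_params(c_in, c_out, k):
    """Conv2d without bias plus an affine BatchNorm (weight and bias per channel)."""
    return c_in * c_out * k * k + 2 * c_out


def backbone_params(in_channels=3, base=64, num_stages=5):
    total = 0
    c_prev = in_channels
    widths = [base * 2 ** i for i in range(num_stages)]
    # encoder: one ConvBlock of two 3x3 Conv-BN-ReLU per stage (max pool has no params)
    for c in widths:
        total += conv_bn_params(c_prev, c, 3) + conv_bn_params(c, c, 3)
        c_prev = c
    # decoder: each UpConvBlock = InterpConv (1x1 conv) + ConvBlock on the concatenation
    for c in widths[:-1]:
        total += conv_bn_params(2 * c, c, 1)  # 1x1 conv after bilinear upsampling
        total += conv_bn_params(2 * c, c, 3)  # first 3x3 conv after concatenation
        total += conv_bn_params(c, c, 3)      # second 3x3 conv
    return total


def fcn_head_params(c_in, channels, num_classes):
    """3x3 Conv-BN-ReLU, then conv_seg (1x1 Conv2d with bias); dropout has no params."""
    return conv_bn_params(c_in, channels, 3) + channels * num_classes + num_classes


backbone = backbone_params()
heads = fcn_head_params(64, 64, 2) + fcn_head_params(128, 64, 2)
print(f"backbone: {backbone:,}  heads: {heads:,}  total: {backbone + heads:,}")
```

Almost all parameters sit in the two deepest encoder stages and their matching decoder stages, since conv cost grows with the product of input and output channels.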
