Hands-On Implementation of a UNet Model in PyTorch

Preface

I have recently been studying CNN-based image segmentation and came across UNet. UNet is a classic network, named after its U-shaped architecture, and it works remarkably well for general image segmentation tasks. The left half of the U is the encoder and the right half is the decoder. The encoder repeatedly downsamples: the channel count grows while the height and width are halved, extracting increasingly abstract image features but discarding spatial (positional) information. The decoder upsamples (up-convolution), fusing in the encoder features at each level to recover the spatial information.

  • UNet architecture diagram
    [Figure 1: UNet architecture]
  • For more implementation details, see the UNet paper.

1. First, the gray arrows in the figure (copy and crop) fuse shallow features with deep features. This preserves the high-resolution detail of the shallow feature maps while exploiting the abstract semantic information of the deep feature maps.
2. Second, during downsampling each feature map shrinks to half the size of the previous level, and during upsampling it doubles; the channel counts change in the opposite direction.
3. The 3x3 convolutions barely change the feature map size. The original paper uses 2x2 max pooling to shrink the maps and upsampling followed by convolution to enlarge them.

  • Since max pooling discards positional information, I decided to replace it with a strided convolution;
  • I use a transposed convolution instead of plain upsampling (interpolation), which achieves the same resizing while also deepening the network.

These choices have pros and cons: if the network is deep enough, the extra convolutions slow down computation and may lead to overfitting.
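
To make the trade-off concrete, here is a minimal shape check (channel counts and sizes are illustrative only, not taken from the network below) showing that a stride-2 convolution resizes feature maps the same way as 2x2 max pooling, and a transposed convolution resizes them the same way as bilinear upsampling:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 128, 128)

# Downsampling: max pooling vs. a stride-2 convolution (both halve H and W)
pooled = nn.MaxPool2d(2)(x)                                          # -> [1, 64, 64, 64]
strided = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)(x)  # -> [1, 128, 64, 64]

# Upsampling: bilinear interpolation vs. a transposed convolution (both double H and W)
upsampled = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)(pooled)  # -> [1, 64, 128, 128]
transposed = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)(pooled)              # -> [1, 32, 128, 128]

print(pooled.shape, strided.shape, upsampled.shape, transposed.shape)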

Now let's implement the UNet structure step by step.

Implementing UNet

The downsampling convolution block

The downsampling step is:

  1. Conv 3x3, BN, ReLU
  2. Conv 3x3, BN, ReLU
  3. Max pool 2x2

as illustrated below:
[Figure 2: Conv 3x3 -> Conv 3x3 -> Max pool 2x2]
However, to make it easier to connect the downsampling and upsampling code, I chain the operations in this order instead:

  1. Max pool 2x2
  2. Conv 3x3, BN, ReLU
  3. Conv 3x3, BN, ReLU

as illustrated below:

[Figure 3: Max pool 2x2 -> Conv 3x3 -> Conv 3x3]

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(Conv 3x3 -> BN -> ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()

        # Unlike the original paper, an optional intermediate channel count is supported here
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.double_conv(x)
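
A quick sanity check for DoubleConv (the input size here is chosen arbitrarily): with padding=1 the two 3x3 convolutions preserve the spatial size, so only the channel count changes.

block = DoubleConv(3, 64)
y = block(torch.randn(1, 3, 128, 128))   # arbitrary 3-channel input
print(y.shape)                           # torch.Size([1, 64, 128, 128])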


# Encoder block
class Down(nn.Module):
    """Downscaling with max pooling (or a strided convolution) followed by DoubleConv"""
    def __init__(self, in_channels, out_channels, maxPool=True):
        super().__init__()

        if maxPool:
            # Max pooling, as in the original paper
            self.downsample = nn.Sequential(
                nn.MaxPool2d(2),
                DoubleConv(in_channels, out_channels),
            )
        else:
            # Use a stride-2 convolution instead of max pooling so positional
            # information is not simply discarded; if the network is very deep,
            # this slows computation and may cause overfitting.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                DoubleConv(out_channels, out_channels),
            )

    def forward(self, x):
        return self.downsample(x)
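
Both variants halve the height and width and map in_channels to out_channels; a quick check with illustrative sizes:

x = torch.randn(1, 64, 128, 128)
print(Down(64, 128)(x).shape)                 # max-pooling branch   -> torch.Size([1, 128, 64, 64])
print(Down(64, 128, maxPool=False)(x).shape)  # strided-conv branch  -> torch.Size([1, 128, 64, 64])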

The upsampling block

Each decoder step restores the spatial size and fuses in the corresponding encoder features:

  • up-conv (bilinear upsampling or transposed convolution)
  • copy and crop (concatenate the matching encoder feature map)
  • Conv 3x3, BN, ReLU
  • Conv 3x3, BN, ReLU

[Figure 4: UNet decoder block]
# Decoder block
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # With bilinear interpolation, the channel reduction is done inside DoubleConv
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Transposed convolution instead of plain upsampling; it halves the channels
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # Following the paper, fuse the features and restore the earlier spatial size
        # x1: feature map from the decoder path (to be upsampled)
        # x2: feature map from the encoder path (skip connection)
        x1 = self.up(x1)
        if (x1.size(2) != x2.size(2)) or (x1.size(3) != x2.size(3)):
            # Input is CHW: pad x1 so its spatial size matches x2 before concatenation
            diffY = x2.size(2) - x1.size(2)
            diffX = x2.size(3) - x1.size(3)
            x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                            diffY // 2, diffY - diffY // 2))
        # Concatenate along the channel dimension, then apply the double convolution
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x
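
A quick check of the skip connection with illustrative sizes (bilinear=False, as at the bottom of the U): the transposed convolution halves the channels and doubles the spatial size, the padding fixes any off-by-one mismatch, and the concatenation restores in_channels before DoubleConv reduces them again.

up = Up(1024, 512, bilinear=False)
x1 = torch.randn(1, 1024, 32, 32)   # decoder feature to be upsampled
x2 = torch.randn(1, 512, 65, 65)    # encoder (skip) feature, deliberately one pixel larger
print(up(x1, x2).shape)             # torch.Size([1, 512, 65, 65])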

Defining the final output layer (output segmentation map)

class outlayer(nn.Module):
    """1x1 convolution mapping the final feature map to per-class score maps"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
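
Because the kernel is 1x1, this layer only remaps channels and leaves the spatial size untouched:

head = outlayer(64, 2)
print(head(torch.randn(1, 64, 128, 128)).shape)  # torch.Size([1, 2, 128, 128])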

Assembling the UNet

class UNet(nn.Module):
    def __init__(self,n_classes,n_channels = 3,bilinear=True) -> None:
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.start = DoubleConv(n_channels,64)

        self.down1 = Down(64,128)
        self.down2 = Down(128,256)
        self.down3 = Down(256,512)
        #self.down4 = Down(512,1024)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.final_conv = outlayer(64, n_classes)

        # self.up1 = Up(1024, 512,bilinear)
        # self.up2 = Up(512, 256,bilinear)
        # self.up3 = Up(256, 128 ,bilinear)
        # self.up4 = Up(128, 64,bilinear)
        # self.final_conv = outlayer(64, n_classes)

    def forward(self,x):
        x0 = self.start(x) #3-64
        #print(x0.shape)
        x1 = self.down1(x0)#64-128
        #print(f"x1.shape:\n{x1.shape}")
        x2 = self.down2(x1)#128-256
        #print(f"x2.shape:\n{x2.shape}")
        x3 = self.down3(x2)#256-512
        #print(f"x3.shape:\n{x3.shape}")
        x4 = self.down4(x3)#512-1024
        #print(f"x4.shape:\n{x4.shape}")

        x = self.up1(x4, x3)#1024-512
        #print(f"x.shape:\n{x.shape}")
        x = self.up2(x, x2)#512-256
        #print(f"x.shape:\n{x.shape}")
        x = self.up3(x, x1)#256-128
        #print(f"x.shape:\n{x.shape}")
        x = self.up4(x, x0)#128-64
        #print(f"x.shape:\n{x.shape}")
        logits = self.final_conv(x)
        #print(f"logits:\n{logits.shape}")
        return logits

Testing that the UNet structure is correct

if __name__ == '__main__':
    dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print(dev)
    net = UNet(n_channels=1, n_classes=2, bilinear=False).to(dev)
    x = torch.randn(1, 1, 572, 572, device=dev)
    out = net(x)
    print(net)
    print(out.shape)

Output:

cuda:0
C:\Users\bxd\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ..\c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
UNet(
  (start): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)     
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)     
      (5): ReLU()
    )
  )
  (down1): Down(
    (downsample): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU()
        )
      )
    )
  )
  (down2): Down(
    (downsample): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU()
        )
      )
    )
  )
  (down3): Down(
    (downsample): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU()
        )
      )
    )
  )
  (down4): Down(
    (downsample): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU()
        )
      )
    )
  )
  (up1): Up(
    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
  )
  (up2): Up(
    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
  )
  (up3): Up(
    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
  )
  (up4): Up(
    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
  )
  (final_conv): outlayer(
    (conv): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
  )
)
torch.Size([1, 2, 572, 572])
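
With the shapes verified, a minimal training-step sketch might look like the following. This is only an illustration: it assumes integer class masks of shape [N, H, W], and the random imgs and masks tensors stand in for a real dataset.

net = UNet(n_channels=1, n_classes=2, bilinear=False)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()           # expects logits [N, C, H, W] and long targets [N, H, W]

imgs = torch.randn(2, 1, 128, 128)          # placeholder batch of grayscale images
masks = torch.randint(0, 2, (2, 128, 128))  # placeholder ground-truth class masks

logits = net(imgs)                          # [2, 2, 128, 128]
loss = criterion(logits, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())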

You will definitely be able to write your own UNet

  • My UNet implementation code
  • We have finally implemented the UNet model in PyTorch. If the implementation process is still not completely clear to you, the following excellent posts should give you a deeper understanding and help you write your own UNet model:
    A PyTorch implementation of UNet: https://blog.csdn.net/kobayashi_/article/details/108951993
    PyTorch: UNet network code explained in detail: https://cloud.tencent.com/developer/article/1633363
    UNet model training, explained in depth: https://blog.csdn.net/weixin_43436958/article/details/107384695
    UNet project source code: https://github.com/milesial/Pytorch-UNet
