UNet网络解读

UNet解读

  • UNet论文
  • UNet的简介
  • 代码解读
    • DoubleConv模块
    • Down模块
    • Up模块
    • OutConv模块
  • 整个UNet
  • 参考资料

UNet论文

UNet论文地址

UNet的简介

UNet网络解读_第1张图片

  • UNet是一个对称的网络结构,左侧为下采样,右侧为上采样
  • 下采样为encoder,上采样为decoder;
  • 四条灰色的平行线,就是在上采样的过程中,融合下采样过程的特征图的通道,Concat
    • 原理就是:一本大小为10cm10cm的书,厚度为3cm的书本(10103)的A书,和一本大小为10cm10cm,厚度为4cm的B书(10103)
    • 将A书和B书,边缘对齐的摞在一起,这样就可以得到一个大小10107的一摞书了
    • 所以对feature map,一个大小为256*256*64的feature map(w为256,h为256,c为64),和一个大小为256*256*32的feature map进行Concat融合,你就会得到一个大小为256*256*96的feature map
    • 在实际使用中,Concat融合的两个feature map的大小不一定相同,例如25625664的feature map和24024032的feature map进行Concat
      • 两种方法
        • 1.将大的256*256*64的feature map进行裁剪,裁剪为240*240*64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到24024096的feature map。
        • 2.将小的240*240*32的feature map进行padding操作,padding为256*256*32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到25625696的feature map。
      • UNet采用的Concat方案就是第二种,将小的feature map进行padding,padding的方式是补0,一种常规的常量填充。(详细看代码Up)

代码解读

  • 组成U-Net的模型块主要有如下几个部分:

    • 1)每个子块内部的两次卷积(Double Convolution)

    • 2)左侧模型块之间的下采样连接,即最大池化(Max pooling)

    • 3)右侧模型块之间的上采样连接(Up sampling)

    • 4)输出层的处理(OutConv)

DoubleConv模块

  • 两次卷积操作:
class DoubleConv(nn.Module):
   # mid_channel是第一次conv的out和第二次conv的输入
   def __init__(self, in_channels, out_channels, mid_channels=None):
       super().__init__()
       if not mid_channels:
           mid_channels = out_channels
       self.double_conv = nn.Sequential(
           # 大小 (高宽 + 2*padding - kernel_size)/stride + 1
           # (572 + 2*0 - 3 )/1 +1 = 570
           # 通道1->64
           nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=0, bias=False),
           # 帮助网络训练, 对输入数据做规范化,称为Covariate shift
           # BatchNorm后是不改变输入的shape
           # num_features: 输入维度,也就是数据的特征维度;
           # eps: 是在分母上加的一个值,是为了防止分母为0的情况,让其能正常计算;
           # affine: 是仿射变化,将,分别初始化为1和0;
           # nn.BatchNorm2d是对channel做归一化处理,也就是对批次内的特征进行归一化
           # 加快收敛,防止梯度爆炸和消失
           nn.BatchNorm2d(mid_channels),
           # inplace = True 时,会修改输入对象的值,所以打印出对象存储地址相同,类似于C语言的址传递
           # inplace = False 时,不会修改输入对象的值,而是返回一个新创建的对象,所以打印出对象存储地址不同,类似于C语言的值传递
           nn.ReLU(inplace=True),
           # (570 + 2*0 - 3 )/1 +1 = 568
           # 设置padding=1(原始的是设置为0,会改变)经过卷积后不会改变特征层的大小,这也是现在主流的实现方式
           nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=0, bias=False),
           nn.BatchNorm2d(out_channels),
           nn.ReLU(inplace=True)
       )
   
   def forward(self, x):
       x = self.double_conv(x)
       return x
  • 注意代码里面的卷积是如何计算的( (高宽 + 2*padding - kernel_size)/stride + 1),如U图中的初始:(572 + 2*0 - 3 )/1 +1 = 570

Down模块

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            # 扩大通道64->128
            DoubleConv(in_channels, out_channels)
        )
        

    def forward(self, x):
        return self.maxpool_conv(x)

Up模块

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # 上采样需要插值,双
        if bilinear:
            # 两倍
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
        	# 反卷积
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
    
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        # 算出相差多少
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # F.pad是pytorch内置的tensor扩充函数,便于对数据集图像或中间层特征进行维度扩充
        # 最后一维padding,第一个元素代表左边padding的个数,第二个元素代表右边padding的个数
        # input:需要扩充的tensor,可以是图像数据,抑或是特征矩阵数据
        # pad:扩充维度,用于预先定义出某维度上的扩充参数
        # mode:扩充方法,’constant‘, ‘reflect’ or ‘replicate’三种模式,分别表示常量,反射,复制
        # value:扩充时指定补充值,但是value只在mode='constant’有效,
        # 即使用value填充在扩充出的新维度位置,而在’reflect’和’replicate’模式下,value不可赋值
        # https://www.modb.pro/db/227153
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
  • __init__初始化函数里定义的上采样方法以及卷积采用DoubleConv
    • 上采样,定义了两种方法:UpsampleConvTranspose2d,也就是双线性插值反卷积

      • 双线性插值:
        UNet网络解读_第2张图片

      • 简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

      • 对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

    • 反卷积:

      • 就是反着卷积

      • UNet网络解读_第3张图片

      • 下面的蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。

      • 这个示意图,就是一个从2*2的feature map->4*4的feature map过程

      • forward前向传播函数中x1接收的是上采样的数据x2接收的是特征融合的数据。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat

OutConv模块

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()  # 和super().__init__()一样
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

整个UNet

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        
        # 最左边
        self.inc = DoubleConv(n_channels, 64)
        # 下采样(里面包括下采样完的两次卷积)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # 插值方式
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 512 // factor, bilinear)
        self.up3 = Up(256, 256 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits =self.outc(x)
        return logits

参考资料

  • UNet语义分割
  • F.pad的使用

你可能感兴趣的:(Pytorch学习,机器学习,#,论文研读,网络,深度学习)