U-Net详解

入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。

目录

一、U-Net产生的原因以及简单介绍

二、U-Net网络结构分析

1、U-Net网络结构图

2、U-Net的Encoder(收缩路径)

 3、U-Net的Decoder(扩展路径)

三、U-Net的pytorch版实现


一、U-Net产生的原因以及简单介绍

产生原因及背景:

U-Net 是为了解决生物医学图像分割问题而产生的。因为它的效果很好,所以后来被广泛应用于语义分割的各个方向:比如卫星图像分割等等。

变体:

U-Net是由FCN衍生而来的,都是 Encoder-Decoder 结构,结构比较简单。

想了解FCN可以看博主往期文章全卷积网络FCN详解_tt丫的博客-CSDN博客

优势和解决的问题:

因为在医学方面,样本收集比较困难,数据量难以达到那么多。为了解决这个问题,U-Net应用了图像增强的方法,在数据集有限的情况下获得了不错的精度。


二、U-Net网络结构分析

1、U-Net网络结构图

数字解释:

其中那些长条的上方(64,128等等)都是通道数;像572 * 572这些是尺寸大小;

蓝白相间上的通道数是两份的总和,即白的和蓝的通道数各是那个通道数的一半

(比如下面这个蓝白相间通道数为128,即代表白的64,蓝的也是64)

框框和箭头解释:

蓝/白色框表示 feature map

蓝色箭头表示 3x3 卷积,用于特征提取;

灰色箭头表示 skip-connection(跳跃连接),用于特征融合;

红色箭头表示最大池化,用于降低维度;

绿色箭头表示上采样,用于恢复维度;

天蓝色(emmm这个颜色是这么描述吧)箭头表示 1x1 卷积,用于输出结果。

U-Net详解_第1张图片

2、U-Net的Encoder(收缩路径)

它是由卷积操作和下采样操作组成。

卷积:

文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1,所以由公式n_{\text {out }}=\left[\frac{n_{\text {in }}+2 p-k}{s}\right]+1

n_{\text {out }}=n_{\text {in }}-2

可以看到第1~5层卷积层都分别是由3个3*3卷积组成,每通过一个3*3卷积尺寸都减少2

池化(下采样):

而前4层卷积层都通过最大池化进入下一层,各池化层的核大小均为k=2,填充均为p=0,步长均为s=2,所以n_{\text {out }}=n_{\text {in }}/2

第5层没有 max-pooling,而是直接将得到的 feature map送入 Decoder

 3、U-Net的Decoder(扩展路径)

feature map 经过 Decoder 恢复原始尺寸,该过程由卷积,上采样和跳级结构组成。

上采样:插值法

补充:

上采样一般有两种方法:

(1)反卷积(详见之前的博文FCN的介绍中全卷积网络FCN详解_tt丫的博客-CSDN博客)(2)插值(bilinear 双线性插值较为常见)

(原来的矩阵称为源矩阵;插值后的矩阵是目标矩阵)

举个栗子:我们要把以下2*2的矩阵插值成4*4

1 2
3

4

A

公式一 —— 目标矩阵到源矩阵的坐标映射:

X_{s r c}=\left(X_{d s t}+0.5\right) *\left(\frac{W i d t h_{s r c}}{W i d t h_{d s t}}\right)-0.5

Y_{s r c}=\left(Y_{d s t}+0.5\right) *\left(\frac{\text { Height }_{s r c}}{\text { Height }_{d s t}}\right)-0.5

A的坐标是(0,1)那么由公式得源矩阵坐标为(-0.25,0.25),是小数没事。

为了找到负数坐标点,我们将源矩阵扩展为下面的形式,中间红色的部分为源矩阵。

1 1 2 2
1 1 2 2
3 3 4 4
3 3 4 4

那么(-0.25,0.25)应该在这里面

1 2
1 2

公式二 —— 具体点的值:

f(i + u, j + v) = (1 - u) (1 - v) f(i, j) + (1 - u) v f(i, j + 1) + u (1 - v) f(i + 1, j) + u v f(i + 1, j + 1)

 可得i = -1, u = 0.75, j = 0, v = 0.25;再由这个公式可得A的值为1.25

其他值以此类推

对应代码

nn.Upsample(scale_factor=2, mode='bilinear')

跳级结构

FCN中的跳级结构解释全卷积网络FCN详解_tt丫的博客-CSDN博客

FCN是采用逐点相加的方法,而U-Net采用将特征在channel维度拼接在一起,形成更“厚”的特征,对应caffe的ConcatLayer层,对应tensorflow的tf.concat()。

对应代码

torch.cat([low_layer_features, deep_layer_features], dim=1)

这两种方法都是为了特征融合。


三、U-Net的pytorch版实现

首先先导入所需要的库

import torch
import torch.nn as nn
import torch.nn.functional as F

对Decorder先进行定义

class Decoder(nn.Module):
  def __init__(self, in_channels,out_channels):
    super(Decoder, self).__init__()
    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) 
    #up-conv 2*2
    self.conv_relu = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
        )

  def forward(self, high, low):
    x1 = self.up(high)
    offset = x1.size()[2]-low.size()[2]
    padding = 2*[offset//2,offset//2]
    #计算应该填充多少(这里可以是负数)
    x2 = F.pad(low,padding)#这里相当于对低级特征做一个crop操作
    x1 = torch.cat((x1, x2), dim=1)#拼起来
    x1 = self.conv_relu(x1)#卷积走起
    return x1

U-Net整体网络框架

class UNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1,64,3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64,128,3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128,128,3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128,256,3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256,256,3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(256,512,3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,512,3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(512,1024,3),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024,1024,3),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True)
        )
        self.decorder4 = Decoder(1024,512)
        self.decorder3 = Decoder(512,256)
        self.decorder2 = Decoder(256,128)
        self.decorder1 = Decoder(128,64)
        self.last = nn.Conv2d(64, n_class, 1)


    def forward(self, input):
        #Encorder
        layer1 = self.layer1(input)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        layer5 = self.layer5(layer4)

        #Decorder
        layer6 = self.decorder4(layer5,layer4)
        layer7 = self.decorder3(layer6,layer3)
        layer8 = self.decorder2(layer7,layer2)
        layer9 = self.decorder1(layer8,layer1)
        out = self.last(layer9)#n_class预测种类数

        return out

欢迎大家在评论区批评指正,谢谢~

你可能感兴趣的:(图像分割,深度学习,计算机视觉,深度学习,人工智能,U-Net,机器学习)