Pytorch深度学习实战教程(二):UNet语义分割网络
Pytorch深度学习实战教程(三):UNet模型训练,深度解析!
深度学习图像分割之UNET
U-Net论文学习笔记
U-net论文笔记
深入理解深度学习分割网络Unet——U-Net: Convolutional Networks for Biomedical Image Segmentation
U-Net: Convolutional Networks for Biomedical Image Segmentation
U-Net:用于生物医学图像分割的卷积网络
本论文提出的语义分割模型在FCN基础上进行改进,只需很少的样本,就得到更高的分割精度。相比与FCN,本文提出的模型在上采样部分的卷积层有更多的通道,这可以将context信息传播到更高分辨率的层。本文的提出的模型中,用于特征提取的收缩网络,和用于还原分辨率的扩张网络,呈现对称的U型架构,因此称为U-net。
Mask = Function(I)
非常直观的“U”形网络结构,分别是进行了四次下采样(图中红色向下的箭头),文中将这部分称为“contracting path”,和四次上采样(图中绿色向上的箭头),在文中这部分称为与“contracting path”近似对称的“expansive path”。U-Net是端到端的训练,同时作者在论文最初也提到该结构训练速度快。
U-net架构:每个蓝框是一个feature map,它上面的数字表示该feature map的通道数(channel),feature map的左下角是它的尺寸。白色方框是从contracting path裁剪复制到expanding path上的feature map。
箭头表示不同的操作:
第一个蓝紫的箭头表示3*3的无填充valid卷积,一次卷积操作后特征图尺寸减小2(例第一次卷积操作后尺寸由572*572输出570*570),如下图,每次卷积后再接RELU激活;
第二个灰色箭头就是复制和裁剪啦,需要裁剪的原因是卷积过程中会有边界像素的丢失;
第三个红色箭头表示大小2*2,步长为2的最大池化操作,即图像分称不重叠的2*2大小的小框,每个小框中保留最大的值,经过最大池化操作后特征图尺寸数变为一半,如下图;需要注意的是为了保证输出分割图的无缝衔接,需要确保输入图像的尺寸(x-size和y-size)必须是偶数;
第四个绿色箭头表示2*2的 up-conv(步长1/2,这里我认为可以等价理解为在图像每个像素之前padding一个0值,然后卷积的步长改为1),这个卷积也称为转置卷积(也有人叫反卷积或逆卷积),示意图如下(示意图中为3*3的卷积核,仅为大家作参考),上卷积的输出channels设置为原先的一半【注意,输出通道数是由卷积核个数决定的,可以人为设定,于是这里我用了“设置”一词】,输出结果再与对应的下采样(U形左侧)的特征图裁剪后串联,得到与上卷积操作前特征图相同的channels数;
第五个蓝绿色箭头是一个1*1的卷积,输出一个两通道的特征图【同理,这里的两通道也是可以人为设置的,只需引用两个1*1卷积即可得到2通道的输出特征图】;关于作者这里为什么要输出2通道的特征图呢?因为该分割任务作者需要将数据区分为前景和背景两类,如果是多分类任务,也可以输出多通道特征图。
ps:我偶尔在分割任务中使用一通道输出,个人认为效果同二通道输出,因为二通道某一点位的输出在不同通道上的和为1,而一通道上的值仅是二通道中一个通道上的值,二分类任务中另一个通道上的值可以通过‘概率求和等于1’推导出来。害,说的有点儿绕口,不知道说清了没,如果大家有不同想法可以提出我们共同讨论呀!
网络共有23个卷积层(即所有蓝紫色+绿色+最后一个蓝绿色的箭头表示的过程)。
1、UNet网络结构,最主要的两个特点是:U型网络结构和Skip Connection跳层连接。
2、DoubleConv模块:
先看下连续两次的卷积操作。
import torch.nn as nn
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
解释下,上述的Pytorch代码:torch.nn.Sequential是一个时序容器,Modules 会以它们传入的顺序被添加到容器中。比如上述代码的操作顺序:卷积->BN->ReLU->卷积->BN->ReLU。
DoubleConv模块的in_channels和out_channels可以灵活设定,以便扩展使用。
如上图所示的网络,in_channels设为1,out_channels为64。
输入图片大小为572*572,经过步长为1,padding为0的3*3卷积,得到570*570的feature map,再经过一次卷积得到568*568的feature map。
计算公式:O=(H−F+2×P)/S+1
H为输入feature map的大小,O为输出feature map的大小,F为卷积核的大小,P为padding的大小,S为步长。
3、Down模块:
UNet网络一共有4次下采样过程,模块化代码如下:
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
这里的代码很简单,就是一个maxpool池化层,进行下采样,然后接一个DoubleConv模块。
至此,UNet网络的左半部分的下采样过程的代码都写好了,接下来是右半部分的上采样过程。
4、Up模块:
上采样过程用到的最多的当然就是上采样了,除了常规的上采样操作,还有进行特征的融合。
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
代码复杂一些,我们可以分开来看,首先是__init__初始化函数里定义的上采样方法以及卷积采用DoubleConv。上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值和反卷积。
双线性插值很好理解,示意图:
熟悉双线性插值的朋友对于这幅图应该不陌生,简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。
对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。
反卷积,顾名思义,就是反着卷积。卷积是让featuer map越来越小,反卷积就是让feature map越来越大,示意图:
下面蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。
这个示意图,就是一个从2*2的feature map->4*4的feature map过程。
在forward前向传播函数中,x1接收的是上采样的数据,x2接收的是特征融合的数据。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat。
5、OutConv模块:
用上述的DoubleConv模块、Down模块、Up模块就可以拼出UNet的主体网络结构了。UNet网络的输出需要根据分割数量,整合输出通道,结果如下图所示:
操作很简单,就是channel的变换,上图展示的是分类为2的情况(通道为2)。
虽然这个操作很简单,也就调用一次,为了美观整洁,也封装一下吧。
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
至此,UNet网络用到的模块都已经写好,我们可以将上述的模块代码都放到一个unet_parts.py文件里,然后再创建unet_model.py,根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序,编写如下代码:
""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""
import torch.nn.functional as F
from unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024)
self.up1 = Up(1024, 512, bilinear)
self.up2 = Up(512, 256, bilinear)
self.up3 = Up(256, 128, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
if __name__ == '__main__':
net = UNet(n_channels=3, n_classes=1)
print(net)
使用命令python unet_model.py,如果没有错误,你会得到如下结果:
UNet(
(inc): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
(down1): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(down2): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(down3): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(down4): Down(
(maxpool_conv): Sequential(
(0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
)
(up1): Up(
(up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(up2): Up(
(up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(up3): Up(
(up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(up4): Up(
(up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
(conv): DoubleConv(
(double_conv): Sequential(
(0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
)
)
(outc): OutConv(
(conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
)
)
Pytorch深度学习实战教程(三):UNet模型训练,深度解析!