Introduction to U-net
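This post introduces U-net, a classic semantic segmentation network. It was proposed in 2015 and first applied to medical image segmentation. Because it performed well there, it has since been widely used for many other segmentation tasks, and numerous segmentation models have been derived from it.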
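U-net is a typical encoder-decoder architecture: the encoder extracts features and the decoder upsamples them. Because training data was limited, U-net relied heavily on data augmentation during training and still achieved good results.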
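For segmentation, any geometric augmentation has to be applied identically to the image and its mask, otherwise the labels no longer line up with the pixels. The sketch below illustrates this with random flips and 90-degree rotations. The function name `augment_pair` and the choice of transforms are assumptions made for illustration; this is not the augmentation pipeline of the original U-net work, which relied mainly on elastic deformations.

```python
import torch


def augment_pair(image, mask, generator=None):
    """Apply the same random flip/rotation to an image and its mask.

    image: (C, H, W) tensor, mask: (H, W) tensor.
    Hypothetical helper for illustration only.
    """
    # Random horizontal flip (the last dimension is width in both tensors).
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    # Random vertical flip (the second-to-last dimension is height).
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=[-2])
        mask = torch.flip(mask, dims=[-2])
    # Random rotation by a multiple of 90 degrees in the H-W plane.
    k = int(torch.randint(0, 4, (1,), generator=generator).item())
    image = torch.rot90(image, k, dims=[-2, -1])
    mask = torch.rot90(mask, k, dims=[-2, -1])
    return image, mask


if __name__ == "__main__":
    g = torch.Generator().manual_seed(0)
    mask = torch.arange(16).reshape(4, 4)
    image = mask.float().unsqueeze(0).repeat(3, 1, 1)
    aug_image, aug_mask = augment_pair(image, mask, generator=g)
    # The image and mask stay pixel-aligned after augmentation.
    print(torch.equal(aug_image[0], aug_mask.float()))
```

Non-square inputs change shape under odd multiples of 90 degrees, so a real pipeline would either restrict rotations or crop afterwards.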
U-net network architecture
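The U-net architecture is as follows. The left side is the encoder, which downsamples the input using max pooling. The right side is the decoder, which upsamples the encoder output to restore resolution, using Upsample. In between are the skip connections, which fuse encoder features with decoder features. The network as a whole is shaped like a "U", hence the name U-net.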
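The effect of each of these operations on tensor shapes can be traced with a few lines of PyTorch. This is a minimal sketch on one illustrative feature map (the sizes are assumptions, not values from the post): max pooling halves the resolution, upsampling doubles it, and the skip connection concatenates along the channel dimension. It also compares the parameter counts of the two common upsampling options.

```python
import torch
import torch.nn as nn

# Illustrative encoder feature map: (batch, channels, height, width).
skip = torch.randn(1, 64, 64, 64)

pool = nn.MaxPool2d(kernel_size=2, stride=2)
upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
transposed = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)

pooled = pool(skip)                    # downsampling halves H and W
restored = upsample(pooled)            # interpolation doubles H and W
restored_learned = transposed(pooled)  # a transposed conv also doubles H and W
fused = torch.cat([restored, skip], dim=1)  # skip connection doubles channels


def count_parameters(module):
    """Number of learnable values in a module."""
    return sum(p.numel() for p in module.parameters())


n_upsample_params = count_parameters(upsample)
n_transposed_params = count_parameters(transposed)

print(pooled.shape)      # torch.Size([1, 64, 32, 32])
print(restored.shape)    # torch.Size([1, 64, 64, 64])
print(fused.shape)       # torch.Size([1, 128, 64, 64])
print(n_upsample_params, n_transposed_params)  # 0 16448
```

Upsample is pure interpolation with nothing to learn, while the transposed convolution adds weights and a bias that are trained along with the rest of the network.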
Apart from the final output layer, every convolutional layer in the network is a 3 * 3 convolution.
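As a small sketch of what these layer choices mean for shapes: a 3 * 3 convolution with padding 1 keeps the spatial size, an unpadded 3 * 3 convolution shrinks each side by 2, and a 1 * 1 output convolution changes only the channel count. The input size and the two output classes below are illustrative assumptions.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 96, 96)  # illustrative feature map

padded_conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
unpadded_conv = nn.Conv2d(64, 64, kernel_size=3)  # padding defaults to 0
output_conv = nn.Conv2d(64, 2, kernel_size=1)     # 2 classes, chosen for illustration

padded_shape = padded_conv(x).shape
unpadded_shape = unpadded_conv(x).shape
output_shape = output_conv(x).shape

print(padded_shape)    # torch.Size([1, 64, 96, 96])
print(unpadded_shape)  # torch.Size([1, 64, 94, 94])
print(output_shape)    # torch.Size([1, 2, 96, 96])
```

Padded convolutions make the output mask the same size as the input image, which is convenient when each pixel needs a label.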
U-net code implementation
import torch as t
import torch.nn as nn

# Two consecutive (3x3 conv -> batch norm -> ReLU) blocks
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.dconv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            # inplace=True saves GPU/CPU memory
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, img):
        return self.dconv(img)

# Downsampling: 2x2 max pooling followed by a double convolution
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2, 2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, img):
        return self.down(img)

# Upsampling
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        # ConvTranspose2d has learnable parameters that are adjusted during training.
        # This increases model complexity and may lead to overfitting.
        # Upsample has no learnable parameters.
        # The difference is analogous to that between Conv2d and MaxPool2d.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder feature map to upsample; x2: encoder feature map from the skip connection
        x1 = self.up(x1)
        # Pad x1 so that it has the same height and width as x2
        dx = x2.shape[3] - x1.shape[3]
        dy = x2.shape[2] - x1.shape[2]
        x1 = nn.functional.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # Concatenate along the channel dimension
        x = t.cat([x1, x2], dim=1)
        return self.conv(x)

# Main network
class CrackUnet(nn.Module):
    def __init__(self, channels, classes, bilinear=True):
        super(CrackUnet, self).__init__()
        self.channels = channels  # number of input image channels
        self.classes = classes    # number of output classes
        self.bilinear = bilinear
        # Input convolution block
        self.inconv = DoubleConv(self.channels, 64)
        # 4 downsampling layers
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # 4 upsampling layers; bilinear interpolation is used when bilinear=True
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # 1x1 output convolution: one output channel per class (not per input channel)
        self.outconv = nn.Conv2d(64, self.classes, 1)

    def forward(self, img):
        img = self.inconv(img)
        down1 = self.down1(img)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        x = self.up1(down4, down3)
        del down4
        del down3
        x = self.up2(x, down2)
        del down2
        x = self.up3(x, down1)
        del down1
        # The fourth upsampling layer is up4; there is no up5
        x = self.up4(x, img)
        del img
        return self.outconv(x)
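
The padding step in Up.forward matters when the input height or width is not divisible by 2 at some level: pooling rounds down, so the upsampled map comes back one pixel short of the skip feature map. The sketch below reproduces that step in isolation. The helper name `pad_to_match` and the tensor sizes are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F


def pad_to_match(x1, x2):
    """Zero-pad x1 so its height and width match x2 (hypothetical helper)."""
    dx = x2.shape[3] - x1.shape[3]  # width difference
    dy = x2.shape[2] - x1.shape[2]  # height difference
    # F.pad takes padding for the last dimension first: (left, right, top, bottom).
    return F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])


# Odd-sized skip feature map: 101 x 75.
skip = torch.randn(1, 64, 101, 75)
pooled = F.max_pool2d(skip, kernel_size=2, stride=2)  # 50 x 37 (rounded down)
upsampled = F.interpolate(pooled, scale_factor=2, mode="bilinear", align_corners=True)  # 100 x 74
aligned = pad_to_match(upsampled, skip)               # 101 x 75 again
fused = torch.cat([aligned, skip], dim=1)             # concatenation now succeeds

print(upsampled.shape)  # torch.Size([1, 64, 100, 74])
print(aligned.shape)    # torch.Size([1, 64, 101, 75])
print(fused.shape)      # torch.Size([1, 128, 101, 75])
```

Without the padding, `torch.cat` would raise an error here because the two tensors differ in height and width.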
Summary
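U-net has a simple, stable structure and is a typical downsampling-plus-upsampling segmentation network. It is especially recommended when the dataset is small.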