论文传送门:https://arxiv.org/pdf/1505.04597.pdf
U-NET是分割任务中的典型网络。
U-NET模型结构:
模型整体呈“U”形,主要分为三个部分:
①左侧contraction,提取特征,整体结构类似VGG(没有BN层);
②右侧expansion,将特征层上采样至原图片大小,最后通过1x1卷积,输出segmentation map;
③中间的copy and crop操作,多尺度特征图融合。
Overlap-tile策略:
U-NET使用Overlap-tile策略,即在图片输入模型前进行镜像padding,使得模型对于图片边缘的预测也有较高的准确率。输入图片(1x388x388)经每边92像素的镜像padding后,得到input image tile(1x572x572)。模型中所有的卷积操作均不进行padding处理,每次3x3卷积都会使特征层的高和宽各减少2,整个网络累计减少184(=2x92)个像素,恰好抵消镜像padding增加的尺寸,所以模型输出的output segmentation map(2x388x388)与输入原图片尺寸相同;同时,由于左侧特征层比右侧对应的特征层尺寸更大,多尺度特征融合时需要先对左侧特征层进行crop操作。
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNET(nn.Module):
    """U-Net segmentation model (Ronneberger et al., 2015).

    The input image is mirror-padded by 92 pixels per side (the paper's
    overlap-tile strategy). It is then passed through a contracting path of
    five unpadded conv blocks and an expanding path of four blocks, joined by
    copy-and-crop skip connections. Every 3x3 convolution is unpadded, and
    the 184 pixels lost along each spatial axis exactly cancel the 2 * 92
    pixels of reflection padding. So for the default 388x388 input the output
    map has the same H and W as the input image.

    By my arithmetic, this size round-trip holds when H and W are congruent
    to 4 (mod 16), e.g. 388, so that every 2x2 max-pool sees an even size.
    Note that ``nn.ReflectionPad2d(92)`` also requires H, W > 92.

    Args:
        img_shape: (C, H, W) of the input image. Only C is used, to size the
            first convolution. Defaults to (1, 388, 388).
        n_classes: number of output channels of the final 1x1 convolution.
            Defaults to 2, as in the original network.
    """

    def __init__(self, img_shape=(1, 388, 388), n_classes=2):
        super().__init__()
        self.img_shape = img_shape  # (C, H, W); only C is read below
        # Overlap-tile mirror padding. 92 px per side compensates for the
        # 184 px that the unpadded convolutions remove along each axis.
        self.reflectionpad = nn.ReflectionPad2d(92)
        # Contracting path: (maxpool) + conv + relu + conv + relu.
        # The first block has no maxpool.
        self.contraction1 = self.contraction(self.img_shape[0], 64, maxpool=False)
        self.contraction2 = self.contraction(64, 128)
        self.contraction3 = self.contraction(128, 256)
        self.contraction4 = self.contraction(256, 512)
        self.contraction5 = self.contraction(512, 1024)
        # Stand-alone up-convolution between the bottleneck and expansion1.
        self.deconv1 = nn.ConvTranspose2d(1024, 512, 2, 2, 0)
        # Expanding path: conv + relu + conv + relu + (transposed conv).
        # The last block has no up-convolution.
        self.expansion1 = self.expansion(1024, 512)
        self.expansion2 = self.expansion(512, 256)
        self.expansion3 = self.expansion(256, 128)
        self.expansion4 = self.expansion(128, 64, upconv=False)
        # 1x1 convolution producing the per-pixel class scores.
        self.conv1 = nn.Conv2d(64, n_classes, 1, 1, 0)

    def contraction(self, in_channel, out_channel, maxpool=True):
        """Build one contracting-path block.

        Layers: optional 2x2/stride-2 maxpool, then two unpadded 3x3 convs
        (each followed by ReLU).

        Args:
            in_channel: channels entering the first convolution.
            out_channel: channels produced by both convolutions.
            maxpool: whether to prepend the 2x2 max-pool.

        Returns:
            nn.Sequential containing the block's layers.
        """
        layers = []
        if maxpool:
            layers.append(nn.MaxPool2d(2, 2))
        layers += [
            nn.Conv2d(in_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def expansion(self, in_channel, out_channel, upconv=True):
        """Build one expanding-path block.

        Layers: two unpadded 3x3 convs (each followed by ReLU), then an
        optional 2x2/stride-2 transposed convolution that halves the channels
        and doubles H and W.

        Args:
            in_channel: channels entering the first convolution (skip
                connection and upsampled features, concatenated).
            out_channel: channels produced by both convolutions.
            upconv: whether to append the transposed convolution.

        Returns:
            nn.Sequential containing the block's layers.
        """
        layers = [
            nn.Conv2d(in_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
        ]
        if upconv:
            layers.append(nn.ConvTranspose2d(out_channel, out_channel // 2, 2, 2, 0))
        return nn.Sequential(*layers)

    def crop(self, x, target_x):
        """Center-crop the spatial dims of ``x`` to those of ``target_x``.

        Used for the copy-and-crop skip connections: a left-path feature map
        (N, C, H, W) is cropped to the right-path size (N, C', H', W').
        Height and width are cropped independently, so non-square maps work.
        When a size difference is odd, the extra pixel is removed from the
        bottom/right side.

        Args:
            x: feature map to crop, shape (..., H, W).
            target_x: tensor whose last two dims (H', W') give the target
                size; requires H' <= H and W' <= W.

        Returns:
            The centered (H', W') window of ``x`` (a view via slicing).

        Raises:
            ValueError: if the target is larger than ``x`` in either dimension.
        """
        height, width = x.shape[-2], x.shape[-1]
        target_h, target_w = target_x.shape[-2], target_x.shape[-1]
        if target_h > height or target_w > width:
            raise ValueError(
                f"cannot crop {height}x{width} feature map to {target_h}x{target_w}"
            )
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return x[..., top:top + target_h, left:left + target_w]

    def forward(self, x):
        """Compute segmentation logits for an image batch.

        Args:
            x: image batch of shape (N, C, H, W). The shape comments below
               trace the default 388x388, single-channel input.

        Returns:
            Logits of shape (N, n_classes, H', W'). H' == H and W' == W for
            compatible sizes such as 388x388 (see the class docstring).
        """
        x = self.reflectionpad(x)     # (n,1,388,388) -> (n,1,572,572)
        x1 = self.contraction1(x)     # -> (n,64,570,570) -> (n,64,568,568)
        x2 = self.contraction2(x1)    # pool (n,64,284,284) -> (n,128,282,282) -> (n,128,280,280)
        x3 = self.contraction3(x2)    # pool (n,128,140,140) -> (n,256,138,138) -> (n,256,136,136)
        x4 = self.contraction4(x3)    # pool (n,256,68,68) -> (n,512,66,66) -> (n,512,64,64)
        x = self.contraction5(x4)     # pool (n,512,32,32) -> (n,1024,30,30) -> (n,1024,28,28)
        x = self.deconv1(x)           # up-conv -> (n,512,56,56)

        x4 = self.crop(x4, x)         # (n,512,64,64) -> (n,512,56,56)
        x = torch.cat((x4, x), dim=1)  # concat along C -> (n,1024,56,56)
        x = self.expansion1(x)        # -> (n,512,54,54) -> (n,512,52,52) -> (n,256,104,104)

        x3 = self.crop(x3, x)         # (n,256,136,136) -> (n,256,104,104)
        x = torch.cat((x3, x), dim=1)  # -> (n,512,104,104)
        x = self.expansion2(x)        # -> (n,256,102,102) -> (n,256,100,100) -> (n,128,200,200)

        x2 = self.crop(x2, x)         # (n,128,280,280) -> (n,128,200,200)
        x = torch.cat((x2, x), dim=1)  # -> (n,256,200,200)
        x = self.expansion3(x)        # -> (n,128,198,198) -> (n,128,196,196) -> (n,64,392,392)

        x1 = self.crop(x1, x)         # (n,64,568,568) -> (n,64,392,392)
        x = torch.cat((x1, x), dim=1)  # -> (n,128,392,392)
        x = self.expansion4(x)        # -> (n,64,390,390) -> (n,64,388,388)
        x = self.conv1(x)             # 1x1 conv -> (n,n_classes,388,388)
        return x