语义分割(Semantic Segmentation):是图像处理和机器视觉一个重要分支。与分类任务不同,语义分割需要判断图像每个像素点的类别,进行精确分割。语义分割目前在自动驾驶、自动抠图、医疗影像等领域有着比较广泛的应用。——语义分割也是一个分类问题!
Unet可以说是最常用、最简单的一种分割模型了,它简单、高效、易懂、容易构建、可以从小数据集中训练。UNet主要贡献是在U型结构上,该结构可以使它使用更少的训练图片的同时,且分割的准确度也不会差,UNet的网络结构如下图:
unet网络非常的简单,前半部分就是特征提取,后半部分是上采样。在一些文献中把这种结构叫做编码器-解码器结构(自编码:标签是数据,编解码结构:标签是分割图),由于网络的整体结构是一个大些的英文字母U,所以叫做U-net。
在当时,Unet相比更早提出的FCN网络,使用拼接来作为特征图的融合方式。
Unet的好处:网络层越深得到的特征图,有着更大的视野域,浅层卷积关注纹理特征,深层网络关注本质的那种特征,所以深层浅层特征都是有各自的意义的;另外一点是通过反卷积得到的更大的尺寸的特征图的边缘,是缺少信息的,毕竟每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现边缘特征的一个找回。
这个结构就是先对图片进行卷积和池化,在Unet论文中是池化4次,比方说一开始的图片224x224的,那么就会变成112x112,56x56,28x28,14x14四个不同尺寸的特征。然后我们对14x14的特征图做上采样或者反卷积,得到28x28的特征图,这个28x28的特征图与之前的28x28的特征图进行通道上的拼接concat,然后再对拼接之后的特征图做卷积和上采样,得到56x56的特征图,再与之前的56x56的特征拼接、卷积,再上采样,经过四次上采样可以得到一个与输入图像尺寸相同的224x224的预测结果。
import torch
import torch.nn as nn
from torch.nn.functional import interpolate
#unet
class CNNlayer(nn.Module):
def __init__(self,c_in,c_Out):
super(CNNlayer, self).__init__()
self.layer=nn.Sequential(
nn.Conv2d(c_in,c_Out,3,1,padding=1,padding_mode="reflect",bias=False),#"reflect"镜像翻转
nn.BatchNorm2d(c_Out),
nn.LeakyReLU(),#原模型为Rule(可能会造成网络退化,负半轴为0)
#LeakyReLU不能防止网络过拟合
nn.Dropout2d(0.3),#随机抑制%30的神经元
#图像分割容易出现过拟合(数据量比较小)
nn.Conv2d(c_Out, c_Out, 3, 1, 1,padding_mode="reflect",bias=False),
nn.BatchNorm2d(c_Out),
nn.LeakyReLU(),
nn.Dropout2d(0.4)
)
def forward(self,x):
return self.layer(x)
#下采样(使用最大池化)---降噪能力较强
class DownSampling(nn.Module):
def __init__(self):
super(DownSampling, self).__init__()
self.layer=nn.Sequential(
nn.MaxPool2d(2)
)
def forward(self,x):
return self.layer(x)
# #2:使用步长为2的卷积做下采样
# class DownSampling(nn.Module):
# def __init__(self,C):
# super(DownSampling, self).__init__()
# self.layer=nn.Sequential(
# nn.Conv2d(C,C,3,2,1,padding_mode="reflect"),
# nn.LeakyReLU(),
# nn.BatchNorm2d(C)
# )
# def forward(self,x):
# return self.layer(x)
#上采样
class UpSampling(nn.Module):
def __init__(self,c):
super(UpSampling, self).__init__()
#特征图大小扩大两倍,通道数减半
self.layer=nn.Sequential(
nn.Conv2d(c, c //2, 3, 1, 1, padding_mode="reflect",bias=False),
nn.BatchNorm2d(c//2),
nn.LeakyReLU(),
)
def forward(self,x,r):
#使用临近插值法进行上采样
up=interpolate(x,scale_factor=2,mode="nearest")#特征图放大两倍
x=self.layer(up)#通道数减半
#通道拼接(cat)
return torch.cat((x,r),dim=1)#nchw
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.C1=CNNlayer(3,64)
#4次下采样
self.D1=DownSampling()
self.C2=CNNlayer(64,128)
self.D2 = DownSampling()
self.C3 = CNNlayer(128, 256)
self.D3 = DownSampling()
self.C4 = CNNlayer(256, 512)
self.D4 = DownSampling()
self.C5 = CNNlayer(512,1024)
#4次上采用
self.U1=UpSampling(1024)
self.C6=CNNlayer(1024,512)
self.U2 = UpSampling(512)
self.C7 = CNNlayer(512, 256)
self.U3 = UpSampling(256)
self.C8 = CNNlayer(256, 128)
self.U4 = UpSampling(128)
self.C9 = CNNlayer(128, 64)
self.pre=nn.Conv2d(64,3,3,1,1)#64-->3
def forward(self,x):
#下采样部分
R1=self.C1(x)
R2=self.C2(self.D1(R1))
R3 = self.C3(self.D2(R2))
R4 = self.C4(self.D3(R3))
R5 = self.C5(self.D4(R4))
#上采用部分
O1 = self.C6(self.U1(R5, R4))
O2 = self.C7(self.U2(O1, R3))
O3 = self.C8(self.U3(O2, R2))
O4 = self.C9(self.U4(O3, R1))
return self.pre(O4)
if __name__ == '__main__':
x=torch.randn(1,3,256,256)
net=UNet()
out=net(x)
print(out.shape)