UNet 是 2015 年提出的、用于解决医学图像分割问题的网络。UNet 由两部分组成:左边部分可以看作特征提取网络,用于提取图像的抽象特征;右边部分可以看作特征融合操作。与传统的 FCN 相比,UNet 使用特征拼接(而非逐元素相加)来实现特征的融合。UNet 通过特征融合操作,实现了浅层的高分辨率特征(越浅层的信息含有越多的细节信息)和深层的低分辨率特征(深层信息含有更多的抽象特征)的融合,充分利用了图像的上下文信息;对称的 U 型结构使得特征融合更加彻底。
上图是 UNet 的网络结构图,其中蓝色方框代表特征图。可以看到,左边部分每一级先进行两层卷积,再进行下采样来提取特征;右边部分则在上采样之后,与左边对应层级的特征图进行拼接。
from torch import nn
import torch
from torch.nn import functional as F
class Conv_Block(nn.Module):
    """Two stacked 3x3 conv -> BatchNorm -> Dropout2d -> ReLU stages.

    Both convolutions use stride 1 and reflect padding of 1, so the spatial
    size of the input is preserved; only the channel count changes
    (``in_channel`` -> ``out_channel``).

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by both convolutions.
        dropout: probability passed to both ``nn.Dropout2d`` layers.
            Defaults to 0.3, the value that was previously hard-coded.
    """

    def __init__(self, in_channel, out_channel, dropout=0.3):
        super(Conv_Block, self).__init__()
        # NOTE: the Dropout2d layers are always present (even for
        # dropout=0.0) so that the indices inside ``self.layer`` -- and
        # therefore the state_dict keys -- never change.
        self.layer = nn.Sequential(
            # kernel 3, stride 1, padding 1 with reflect padding: the
            # height/width of the feature map are unchanged.
            nn.Conv2d(in_channel, out_channel, 3, 1, 1, padding_mode='reflect',
                      bias=False),  # bias is redundant before BatchNorm
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(dropout),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect',
                      bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(dropout),
            nn.ReLU()
        )

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.layer(x)
class DownSample(nn.Module):
    """Down-sample a feature map with a stride-2 convolution.

    A 3x3 convolution (stride 2, reflect padding 1, no bias) followed by
    BatchNorm and LeakyReLU. The channel count is kept; the height and
    width are roughly halved.

    Args:
        channel: number of input (and output) channels.
    """

    def __init__(self, channel):
        super().__init__()
        strided_conv = nn.Conv2d(
            channel,
            channel,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode='reflect',
            bias=False,
        )
        stages = [strided_conv, nn.BatchNorm2d(channel), nn.LeakyReLU()]
        # Same three modules, same order, so the indices inside
        # ``self.layer`` are unchanged.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Return the down-sampled version of ``x``."""
        halved = self.layer(x)
        return halved
class UpSample(nn.Module):
    """Up-sample with nearest-neighbour interpolation, then fuse a skip map.

    ``x`` is resized to the spatial size of ``feature_map``, a 1x1
    convolution halves its channel count, and the result is concatenated
    with ``feature_map`` along the channel dimension.

    Args:
        channel: number of channels of the tensor to be up-sampled; the
            1x1 convolution reduces it to ``channel // 2``.
    """

    def __init__(self, channel):
        super(UpSample, self).__init__()
        # 1x1 convolution used purely to halve the channel count.
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)

    def forward(self, x, feature_map):
        """Return ``cat(conv1x1(upsample(x)), feature_map)`` on dim 1.

        Args:
            x: 4-D tensor (N, channel, h, w) coming from the deeper level.
            feature_map: 4-D skip tensor (N, C_skip, H, W) from the encoder.

        Returns:
            Tensor of shape (N, channel // 2 + C_skip, H, W).
        """
        # Resize to the skip map's exact height/width instead of a fixed
        # scale_factor=2. When H == 2*h and W == 2*w this is the same
        # nearest-neighbour result as before; when the encoder input had an
        # odd size (stride-2 conv rounds up), a fixed 2x would overshoot by
        # one pixel and torch.cat would fail on the size mismatch.
        up = F.interpolate(x, size=feature_map.shape[2:], mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)
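# Summary: UpSample enlarges the feature map and halves the channels of x (before the skip concatenation); DownSample halves the feature map and keeps the channel count -- the Conv_Block that follows it is what increases the channels.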
class UNet(nn.Module):
    """U-Net assembled from Conv_Block, DownSample and UpSample.

    The encoder applies five Conv_Blocks (64, 128, 256, 512, 1024 channels)
    separated by four DownSample steps. The decoder applies four UpSample
    steps, each concatenating the matching encoder output, followed by a
    Conv_Block. A final 3x3 convolution and a sigmoid produce the output.

    Args:
        in_channel: number of channels of the input image. Defaults to 3,
            the value that was previously hard-coded.
        out_channel: number of channels of the predicted map. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channel=3, out_channel=3):
        super(UNet, self).__init__()
        # Attribute names and construction order are kept as-is so that
        # state_dict keys and seeded initialisation stay the same.
        # --- encoder ---
        self.c1 = Conv_Block(in_channel, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)    # channels increase here
        self.d2 = DownSample(128)        # spatial down-sampling
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)  # bottleneck
        # --- decoder ---
        # Each UpSample(c) outputs c // 2 channels from x plus the skip
        # map's c // 2 channels, i.e. c channels fed to the next Conv_Block.
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # --- head ---
        self.out = nn.Conv2d(64, out_channel, 3, 1, 1)
        self.Th = nn.Sigmoid()

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output.

        Args:
            x: input batch with ``in_channel`` channels.

        Returns:
            Tensor with ``out_channel`` channels; values lie in (0, 1)
            because of the final sigmoid.
        """
        # Encoder: four down-samplings give five resolutions (R1..R5).
        R1 = self.c1(x)             # 64 channels
        R2 = self.c2(self.d1(R1))   # size halved, 128 channels
        R3 = self.c3(self.d2(R2))   # 256 channels
        R4 = self.c4(self.d3(R3))   # 512 channels
        R5 = self.c5(self.d4(R4))   # 1024 channels
        # Decoder: up-sample, fuse with the matching encoder output, convolve.
        O1 = self.c6(self.u1(R5, R4))  # R5 up-sampled and fused with R4 -> 512
        O2 = self.c7(self.u2(O1, R3))  # 256
        O3 = self.c8(self.u3(O2, R2))  # 128
        O4 = self.c9(self.u4(O3, R1))  # 64
        return self.Th(self.out(O4))
if __name__ == "__main__":
    # Smoke test: push a random batch of two 3-channel 256x256 images
    # through a freshly built network and print the output shape.
    sample_batch = torch.randn(2, 3, 256, 256)
    model = UNet()
    prediction = model(sample_batch)
    print(prediction.shape)