# 复现代码时，建议遵循从左下到右上的顺序（先沿编码器算出第一列 x_i_0，再从 j=1 开始逐列计算嵌套节点 x_i_j，最终到达右上角的 x_0_4），这样思路会更清晰。UNet++ 原论文的结构图给出的细节不多，建议先复现 UNet，之后再上手 UNet++ 就很容易了。
# coding:utf8
from modulefinder import Module
import torch
from torch import nn
class Conv(nn.Module):
    """Two stacked 3x3 convolutions with BatchNorm and SiLU.

    Both convolutions use stride 1 and padding 1, so the spatial size is
    preserved; the channel count goes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by each of the two convolutions.
        pre_BachNorm: if True, use the pre-activation ordering
            (BatchNorm -> SiLU -> Conv, twice), which normalizes the input
            first and ends on a raw convolution output. If False, use the
            post-activation ordering (Conv -> BatchNorm -> SiLU, twice).
    """

    def __init__(self, in_channels, out_channels, pre_BachNorm=False):
        super().__init__()

        def conv3x3(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1)

        if not pre_BachNorm:
            layers = [
                conv3x3(in_channels, out_channels),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
                conv3x3(out_channels, out_channels),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
            ]
        else:
            layers = [
                nn.BatchNorm2d(in_channels),
                nn.SiLU(),
                conv3x3(in_channels, out_channels),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
                conv3x3(out_channels, out_channels),
            ]
        # Layer order is kept fixed so state_dict keys stay ``conv.0`` ... ``conv.5``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
# Downsampling
class Down_Conv(nn.Module):
    """Halve the spatial resolution with a learned strided convolution.

    A 3x3 convolution with stride 2 and padding 1 maps an H x W input to
    ceil(H/2) x ceil(W/2), followed by BatchNorm and SiLU. The channel count
    is unchanged.

    A plain ``MaxPool2d(2)`` would also downsample. A strided convolution is
    used instead so the downsampling step itself can learn to mix features.
    """

    def __init__(self, channels):
        super().__init__()
        strided = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.down_conv = nn.Sequential(strided, nn.BatchNorm2d(channels), nn.SiLU())

    def forward(self, x):
        return self.down_conv(x)
# Upsampling
class Up_Conv(nn.Module):
    """Double the spatial resolution with a learned transposed convolution.

    With kernel 3, stride 2, padding 1 and output_padding 1, the output height
    is (H - 1) * 2 - 2 * 1 + 3 + 1 = 2H (likewise for width), so the result
    lines up exactly with a feature map one level shallower. It is followed by
    BatchNorm and SiLU.

    Args:
        in_channels: channels of the incoming (coarser) feature map.
        out_channels: channels of the upsampled output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.up = nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.SiLU())

    def forward(self, x):
        return self.up(x)
class UnetPulsPuls(nn.Module):
    """UNet++ (nested U-Net) segmentation network with optional deep supervision.

    Node ``x_i_j`` is the feature map at depth ``i`` (0 = full resolution) in
    column ``j``. Column 0 (``x_i_0``) is the encoder. Every other node
    concatenates all earlier nodes of its own row with the upsampled node
    ``x_{i+1}_{j-1}`` from the row below, then applies ``Conv``.

    Channel widths are 64, 128, 256, 512 and 1024 for depths 0-4. The encoder
    downsamples four times with ``Down_Conv`` (stride-2 conv, output size
    ceil(H/2)), and every upsampling uses ``Up_Conv`` (exactly 2x). The input
    height and width should therefore be divisible by 16, otherwise the
    concatenations see mismatched spatial sizes.

    Args:
        supervised: if True, ``forward`` returns the four predictions made from
            the top-row nodes ``(x_0_1, x_0_2, x_0_3, x_0_4)`` for deep
            supervision. Otherwise it returns only the prediction from ``x_0_4``.
            All four predictions go through the same head ``self.end``.
        in_channels: channels of the input image. Defaults to 3.
        num_classes: channels of the output map. Defaults to 3. With more than
            one class the head ends in a softmax over the channel dimension.
            With a single class it ends in a sigmoid, because a softmax over a
            single channel is constantly 1.

    Raises:
        ValueError: if ``num_classes`` is smaller than 1.
    """

    def __init__(self, supervised, in_channels=3, num_classes=3):
        super(UnetPulsPuls, self).__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.supervised = supervised

        # Encoder (column j = 0). stageK computes x_{K-1}_0, stageK_down feeds
        # the next depth, and stageK_up upsamples x_{K-1}_0 for the row above.
        self.stage1 = Conv(in_channels, 64, pre_BachNorm=True)
        self.stage1_down = Down_Conv(64)
        self.stage2 = Conv(64, 128, True)
        self.stage2_up = Up_Conv(128, 64)
        self.stage2_down = Down_Conv(128)
        self.stage3 = Conv(128, 256, True)
        self.stage3_up = Up_Conv(256, 128)
        self.stage3_down = Down_Conv(256)
        self.stage4 = Conv(256, 512, True)
        self.stage4_up = Up_Conv(512, 256)
        self.stage4_down = Down_Conv(512)
        self.stage5 = Conv(512, 1024, True)
        self.stage5_up = Up_Conv(1024, 512)

        # Nested nodes. x_i_j concatenates j + 1 maps of its row's width:
        # the j earlier nodes of row i plus one upsampled node from row i + 1.
        # The *_up layers upsample a nested node for use by the row above.
        self.x_0_1 = Conv(64 * 2, 64)
        self.x_0_2 = Conv(64 * 3, 64)
        self.x_0_3 = Conv(64 * 4, 64)
        self.x_0_4 = Conv(64 * 5, 64)
        self.x_1_1 = Conv(128 * 2, 128)
        self.x_1_1_up = Up_Conv(128, 64)
        self.x_1_2 = Conv(128 * 3, 128)
        self.x_1_2_up = Up_Conv(128, 64)
        self.x_1_3 = Conv(128 * 4, 128)
        self.x_1_3_up = Up_Conv(128, 64)
        self.x_2_1 = Conv(256 * 2, 256)
        self.x_2_1_up = Up_Conv(256, 128)
        self.x_2_2 = Conv(256 * 3, 256)
        self.x_2_2_up = Up_Conv(256, 128)
        self.x_3_1 = Conv(512 * 2, 512)
        self.x_3_1_up = Up_Conv(512, 256)

        # Prediction head, shared by every deep-supervision output.
        self.end = nn.Sequential(
            nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1),
            nn.Softmax(dim=1) if num_classes > 1 else nn.Sigmoid(),
        )

    def forward(self, x):
        # Column 0: encoder path, each level at half the previous resolution.
        x_0_0 = self.stage1(x)
        x_1_0 = self.stage2(self.stage1_down(x_0_0))
        x_2_0 = self.stage3(self.stage2_down(x_1_0))
        x_3_0 = self.stage4(self.stage3_down(x_2_0))
        x_4_0 = self.stage5(self.stage4_down(x_3_0))
        # Column 1.
        x_0_1 = self.x_0_1(torch.cat([x_0_0, self.stage2_up(x_1_0)], dim=1))
        x_1_1 = self.x_1_1(torch.cat([x_1_0, self.stage3_up(x_2_0)], dim=1))
        x_2_1 = self.x_2_1(torch.cat([x_2_0, self.stage4_up(x_3_0)], dim=1))
        x_3_1 = self.x_3_1(torch.cat([x_3_0, self.stage5_up(x_4_0)], dim=1))
        # Column 2.
        x_0_2 = self.x_0_2(torch.cat([x_0_0, x_0_1, self.x_1_1_up(x_1_1)], dim=1))
        x_1_2 = self.x_1_2(torch.cat([x_1_0, x_1_1, self.x_2_1_up(x_2_1)], dim=1))
        x_2_2 = self.x_2_2(torch.cat([x_2_0, x_2_1, self.x_3_1_up(x_3_1)], dim=1))
        # Column 3.
        x_0_3 = self.x_0_3(torch.cat([x_0_0, x_0_1, x_0_2, self.x_1_2_up(x_1_2)], dim=1))
        x_1_3 = self.x_1_3(torch.cat([x_1_0, x_1_1, x_1_2, self.x_2_2_up(x_2_2)], dim=1))
        # Column 4: the final full-resolution node.
        x_0_4 = self.x_0_4(torch.cat([x_0_0, x_0_1, x_0_2, x_0_3, self.x_1_3_up(x_1_3)], dim=1))
        if self.supervised:
            return self.end(x_0_1), self.end(x_0_2), self.end(x_0_3), self.end(x_0_4)
        return self.end(x_0_4)
if __name__ == '__main__':
    # Smoke test: one deep-supervision forward pass plus a toy training loss.
    torch.manual_seed(0)
    xx = torch.randn((1, 3, 640, 640))
    # One-hot target laid out like the softmax output (N, C, H, W). A uniform
    # random tensor is not a valid per-pixel class distribution.
    labels = torch.randint(0, 3, (1, 640, 640))
    mask = nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()
    model = UnetPulsPuls(supervised=True)
    # No backward pass is run here, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(xx)
    for name, out in zip(('x_0_1', 'x_0_2', 'x_0_3', 'x_0_4'), outputs):
        print(name, tuple(out.shape))
    # Signed differences (mask - out) cancel each other out and do not reduce
    # to a scalar. Use the mean squared error of each output instead, averaged
    # over the four deep-supervision heads.
    loss = sum(torch.mean((mask - out) ** 2) for out in outputs) / len(outputs)
    print('loss', loss.item())