There are two main ways to improve the original UNet: improving its convolution blocks, and improving the overall model structure.
For the convolution blocks, the plain convolution blocks can be replaced with residual blocks.
For the model structure, additional U-shapes can be nested into the network.
ResNet comes in five variants: resnet18, resnet34, resnet50, resnet101 and resnet152, where the trailing number counts the weighted layers. The first two variants are built from basic blocks, the last three from bottleneck blocks.
Figure 1. Image from [1].
As shown in Figure 1, the left side is a basic block and the right side is a bottleneck block. A basic block consists of two 3×3 convolutions with the same number of channels, plus a shortcut that bypasses them. A bottleneck block consists of three convolutions with kernel sizes 1×1, 3×3 and 1×1, where the first two share the same channel count and the last has 4 times as many, plus a shortcut that bypasses all three; with 64 base channels, for example, the sequence is 1×1, 64 → 3×3, 64 → 1×1, 256.
Figure 2. Image from [1].
Figure 2 lists the kernel sizes and channel counts used inside each of the five ResNet variants.
The structural improvement to UNet is nesting, i.e. UNet++, shown in Figure 3:
Figure 3. Image from [2].
The original UNet downsamples n times and then upsamples n times. On top of this, UNet++ also upsamples n-1 times immediately after the (n-1)-th downsampling, and so on, down to upsampling once after the first downsampling. Every newly added upsampling node also receives its corresponding skip connections.
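In the notation of [2], with x[i,j] the node at downsampling depth i and nesting level j, H(.) a convolution block, up(.) 2× upsampling, and down(.) the encoder's downsampling, this wiring can be summarized as (a sketch of the paper's formulation):

x[i,0] = H(down(x[i-1,0]))
x[i,j] = H(concat(x[i,0], ..., x[i,j-1], up(x[i+1,j-1])))  for j > 0

so each node densely receives every earlier node at the same depth plus one upsampled node from the depth below.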
UNet++ has 4 main variants, L1 through L4, which differ in how deep the nested decoder goes; thanks to deep supervision, a deep model can be pruned to a shallower variant at inference. The variants are shown in Figure 4:
Figure 4. Image from [2].
The code combining the ResNet blocks with UNet++ L2 follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    # Two 3x3 convolutions with the same channel count, plus a shortcut.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )
        # Project the shortcut with a 1x1 convolution when the spatial size or
        # channel count changes; otherwise it is an identity mapping.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class BottleNeck(nn.Module):
    # 1x1 reduce -> 3x3 -> 1x1 expand; the last convolution has 4x the channels.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class ResNet(nn.Module):
    def __init__(self, in_chans, block, num_block, num_classes=100):
        super().__init__()
        self.block = block  # kept so ResNet_UNetpp below can borrow the block type
        self.in_channels = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        # The first block of each stage may downsample; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        f1 = self.conv1(x)
        f2 = self.conv2_x(self.pool(f1))
        f3 = self.conv3_x(f2)
        f4 = self.conv4_x(f3)
        f5 = self.conv5_x(f4)
        output = self.avg_pool(f5)
        output = output.view(output.size(0), -1)
        output = self.fc(output)
        # Early feature maps are returned alongside the logits so the network
        # can also serve as a segmentation encoder.
        return f1, f2, f3, output

def resnet18(in_chans):
    return ResNet(in_chans, BasicBlock, [2, 2, 2, 2])

def resnet34(in_chans):
    return ResNet(in_chans, BasicBlock, [3, 4, 6, 3])

def resnet50(in_chans):
    return ResNet(in_chans, BottleNeck, [3, 4, 6, 3])

def resnet101(in_chans):
    return ResNet(in_chans, BottleNeck, [3, 4, 23, 3])

def resnet152(in_chans):
    return ResNet(in_chans, BottleNeck, [3, 8, 36, 3])
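As a quick check of the channel arithmetic described under Figure 1, the two block types can be probed with a dummy tensor (a minimal sketch using the classes defined above):

blk = BasicBlock(64, 64)    # expansion 1: 64 output channels
btl = BottleNeck(64, 64)    # expansion 4: 256 output channels
t = torch.randn(1, 64, 56, 56)
print(blk(t).shape)         # torch.Size([1, 64, 56, 56])
print(btl(t).shape)         # torch.Size([1, 256, 56, 56])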
"""### ResNet_UNetpp"""
class ConvBlock(nn.Module):
    # Plain double 3x3 convolution block (the original, non-residual UNet
    # block; kept for reference, not used by the model below).
    def __init__(self, in_chans, out_chans, stride):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_chans)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_chans)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(x)))
        return out
class UpConvBlock(nn.Module):
    # Upsample with a transposed convolution, concatenate all skip inputs at
    # this level, then fuse with a residual block.
    def __init__(self, in_chans, bridge_chans_list, out_chans):
        super(UpConvBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
        # BasicBlock (expansion 1) is used here regardless of the backbone, so
        # the fused output keeps exactly out_chans channels.
        self.conv_block = BasicBlock(out_chans + sum(bridge_chans_list), out_chans, 1)

    def forward(self, x, bridge_list):
        x = self.up(x)
        x = torch.cat([x] + bridge_list, dim=1)
        out = self.conv_block(x)
        return out
class ResNet_UNetpp(nn.Module):
    def __init__(self, in_chans=1, n_classes=2, backbone=resnet18):
        super(ResNet_UNetpp, self).__init__()
        '''
        Compatible with resnet18/34/50/101/152.
        '''
        # Instantiate the backbone once, only to borrow its block type and
        # expansion factor.
        block = backbone(in_chans).block
        exp = block.expansion
        feat_chans = [64 * exp, 128 * exp, 256 * exp]
        # Encoder: three levels of residual blocks (x00 -> x10 -> x20).
        self.conv_x00 = block(in_chans, feat_chans[0] // exp, 1)
        self.conv_x10 = block(feat_chans[0], feat_chans[1] // exp, 2)
        self.conv_x20 = block(feat_chans[1], feat_chans[2] // exp, 2)
        # Nested decoder nodes of UNet++ L2; the bridge lists carry the skips.
        self.conv_x01 = UpConvBlock(feat_chans[1], [feat_chans[0]], feat_chans[0])
        self.conv_x11 = UpConvBlock(feat_chans[2], [feat_chans[1]], feat_chans[1])
        self.conv_x02 = UpConvBlock(feat_chans[1], [feat_chans[0], feat_chans[0]], feat_chans[0])
        # One 1x1 classification head per supervised node (deep supervision).
        self.cls_conv_x01 = nn.Conv2d(feat_chans[0], n_classes, kernel_size=1)
        self.cls_conv_x02 = nn.Conv2d(feat_chans[0], n_classes, kernel_size=1)

    def forward(self, x):
        x00 = self.conv_x00(x)
        x10 = self.conv_x10(x00)
        x20 = self.conv_x20(x10)
        x01 = self.conv_x01(x10, [x00])
        x11 = self.conv_x11(x20, [x10])
        x02 = self.conv_x02(x11, [x00, x01])
        out01 = self.cls_conv_x01(x01)
        out02 = self.cls_conv_x02(x02)
        print('x00', x00.shape)
        print('x10', x10.shape)
        print('x20', x20.shape)
        print('x01', x01.shape)
        print('x11', x11.shape)
        print('x02', x02.shape)
        print('out01', out01.shape)
        print('out02', out02.shape)
        return out01, out02
x = torch.randn((2, 1, 224, 224), dtype=torch.float32)
model = ResNet_UNetpp(in_chans=1, backbone=resnet50)
y1, y2 = model(x)
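During training, the two heads can be supervised jointly (deep supervision, following [2]), while at inference the network can be pruned by keeping only the shallower head. A minimal sketch, assuming a standard cross-entropy setup; `criterion` and `labels` are hypothetical names introduced here:

criterion = nn.CrossEntropyLoss()
labels = torch.randint(0, 2, (2, 224, 224))  # assumed integer label maps
loss = 0.5 * (criterion(y1, labels) + criterion(y2, labels))
# At inference, using only y1 amounts to the pruned UNet++ L1 model.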
References:
[1] He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
[2] Zhou Z, Siddiquee M M R, Tajbakhsh N, et al. UNet++: A nested U-Net architecture for medical image segmentation[M]//Deep learning in medical image analysis and multimodal learning for clinical decision support. Springer, Cham, 2018: 3-11.