U-Net网络算是医学图像分割领域的开山之作。我接触深度学习到现在将近大半年时间,看到了很多基于U-Net网络的变体,后续也会继续和大家一起分享学习。这次分享ResNet+U-Net的一个改进版:ResNet本身就是一个十分优秀的backbone,目前应用仍然十分广泛;将ResNet作为encoder融合进U-Net后,就可以借助预训练权重让U-Net进行迁移学习了。
1.将Resnet作为encoder替换U-Net原始结构
2.U-Net提出时间较早,当时还没有例如resnet等网络结构和大规模预训练权重可用
3.U-Net下采样的设计与诸多(如今)成熟的网络结构异曲同工
每下采样一层,特征图尺寸减半(/2)
每下采样一层,特征图channels数翻倍(x2)
4.成熟的网络结构、ImageNet预训练权重可以用来finetuning我们的U-Net,从而起到优化U-Net网络的效果
"""
ResNet34 + U-Net
"""
import torch
from torch import nn
import torchvision.models as models
import torch.nn.functional as F
from torchsummary import summary
class expansive_block(nn.Module):
    """Decoder stage: upsample, optionally merge a skip tensor, then refine.

    The refinement is two ``Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d``
    groups, so the spatial size produced by the upsampling step is preserved.

    Args:
        in_channels: Channels entering the first convolution. When a skip
            tensor is passed to ``forward`` this must equal the channels of
            the skip tensor plus the channels of the upsampled tensor.
        mid_channels: Channels produced by the first convolution.
        out_channels: Channels produced by the second convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, d, e=None):
        """Upsample ``d`` and run the convolution stack.

        Args:
            d: 4-D ``(N, C, H, W)`` tensor coming from the deeper stage
                (bilinear interpolation requires a 4-D input).
            e: Optional skip tensor. When given, ``d`` is resized to the
                spatial size of ``e`` and the two are concatenated along the
                channel axis as ``[e, d]``.

        Returns:
            The output of the convolution stack, with ``out_channels``
            channels.
        """
        if e is None:
            # No skip connection: plain 2x bilinear upsampling.
            d = F.interpolate(d, scale_factor=2, mode='bilinear', align_corners=True)
            return self.block(d)

        # Resize directly to the skip tensor's height/width instead of a fixed
        # 2x factor. When the skip is exactly twice as large this yields the
        # same result (with align_corners=True the sampling grid depends only
        # on the input/output sizes); when it is not (e.g. an odd-sized
        # encoder map) it keeps torch.cat from failing on a size mismatch.
        d = F.interpolate(d, size=e.shape[2:], mode='bilinear', align_corners=True)
        return self.block(torch.cat([e, d], dim=1))
def final_block(in_channels, out_channels):
    """Build the output head of the network.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced map.

    Returns:
        An ``nn.Sequential`` of ``Conv2d(3x3, padding=1)``, ``ReLU`` and
        ``BatchNorm2d(out_channels)``, in that order.
    """
    head_conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1)
    layers = [head_conv, nn.ReLU(), nn.BatchNorm2d(out_channels)]
    return nn.Sequential(*layers)
class Resnet34_Unet(nn.Module):
    """U-Net style network that uses a torchvision ResNet-34 as its encoder.

    The ResNet stem (conv1/bn1/relu/maxpool) and its four residual stages
    form the encoder, a two-convolution bottleneck with a 2x max-pool sits
    below it, and five ``expansive_block`` stages followed by
    ``final_block`` form the decoder. The four deepest decoder stages
    receive skip connections from the residual stages.

    Args:
        in_channel: Accepted for interface compatibility but NOT used to
            build the network; the ResNet stem is reused unchanged. A
            warning is emitted when a value other than 3 is passed.
        out_channel: Number of channels in the output map.
        pretrained: Forwarded to ``models.resnet34(pretrained=...)``.

    NOTE(review): the decoder upsamples five times while the encoder plus
    bottleneck appear to downsample six times with torchvision's ResNet-34
    strides, so the output is presumably half the input resolution —
    confirm against the intended use.
    """

    def __init__(self, in_channel, out_channel, pretrained=False):
        super().__init__()
        if in_channel != 3:
            # The argument used to be silently ignored; surface that so a
            # caller expecting e.g. a 1-channel network is not misled.
            import warnings
            warnings.warn(
                "Resnet34_Unet: in_channel=%r is ignored; the ResNet-34 stem "
                "is used unchanged and expects 3-channel input." % (in_channel,),
                stacklevel=2,
            )

        # Kept as an attribute (not a local) so state_dict keys stay
        # compatible with checkpoints saved from earlier versions.
        self.resnet = models.resnet34(pretrained=pretrained)
        self.layer0 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
            self.resnet.maxpool,
        )

        # Encode: the four residual stages of the backbone.
        self.layer1 = self.resnet.layer1
        self.layer2 = self.resnet.layer2
        self.layer3 = self.resnet.layer3
        self.layer4 = self.resnet.layer4

        # Bottleneck: widen 512 -> 1024 channels, then halve the spatial size.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        )

        # Decode: each input width is (channels from below) + (skip channels).
        self.conv_decode4 = expansive_block(1024 + 512, 512, 512)
        self.conv_decode3 = expansive_block(512 + 256, 256, 256)
        self.conv_decode2 = expansive_block(256 + 128, 128, 128)
        self.conv_decode1 = expansive_block(128 + 64, 64, 64)
        # The last stage has no skip connection.
        self.conv_decode0 = expansive_block(64, 32, 32)
        self.final_layer = final_block(32, out_channel)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the output map.

        Args:
            x: Input batch fed to the ResNet stem.

        Returns:
            The tensor produced by ``final_layer`` (``out_channel`` channels).
        """
        x = self.layer0(x)

        # Encode, keeping every stage output for the skip connections.
        encode_block1 = self.layer1(x)
        encode_block2 = self.layer2(encode_block1)
        encode_block3 = self.layer3(encode_block2)
        encode_block4 = self.layer4(encode_block3)

        # Bottleneck
        bottleneck = self.bottleneck(encode_block4)

        # Decode, pairing each stage with the encoder output of matching depth.
        decode_block4 = self.conv_decode4(bottleneck, encode_block4)
        decode_block3 = self.conv_decode3(decode_block4, encode_block3)
        decode_block2 = self.conv_decode2(decode_block3, encode_block2)
        decode_block1 = self.conv_decode1(decode_block2, encode_block1)
        decode_block0 = self.conv_decode0(decode_block1)

        return self.final_layer(decode_block0)
# Set to 1 to run a quick forward-pass shape check before the summary below.
flag = 0

if flag:
    # Use a size divisible by 64 so every downsampling stage (encoder plus
    # bottleneck) halves cleanly and the skip tensors line up in the decoder.
    image = torch.rand(1, 3, 512, 512)
    # Bind the instance to its own name; reusing the class name here would
    # shadow the class and break the construction further down.
    demo_net = Resnet34_Unet(in_channel=3, out_channel=1)
    mask = demo_net(image)
    print(mask.shape)

# Smoke-test the network: build it on the available device and print a
# per-layer summary for a 3-channel 512x512 input.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# in_channel matches the 3-channel input given to summary() below.
model = Resnet34_Unet(in_channel=3, out_channel=1, pretrained=True).to(device)
summary(model, input_size=(3, 512, 512))