【医学图像分割网络】之Res U-Net网络PyTorch复现

【医学图像分割网络】之Res U-Net网络PyTorch复现

1.内容

U-Net网络算是医学图像分割领域的开山之作,我接触深度学习到现在大概将近大半年时间,看到了很多基于U-Net网络的变体,后续也会继续和大家一起分享学习。这次分享ResNet+U-Net的一个改进版,ResNet本身就是一个十分优秀的backbone,目前应用仍然十分广泛,我们融合进ResNet后,我们就可以使得U-Net进行迁移学习了。
1.将Resnet作为encoder替换U-Net原始结构
2.U-Net提出时间较早,当时还没有例如resnet等网络结构和大规模预训练权重可用
3.U-Net下采样的设计与诸多(如今)成熟的网络结构异曲同工
特征图每一层降低尺寸/2
特征图每一层channels数翻倍x2
4.成熟的网络结构、ImageNet预训练权重可以用来finetuning我们的U-Net,从而起到优化U-Net网络的效果
【医学图像分割网络】之Res U-Net网络PyTorch复现_第1张图片

2.代码

"""
ResNet34 + U-Net
"""
import torch
from torch import nn
import torchvision.models as models
import torch.nn.functional as F
from torchsummary import summary


class expansive_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()

        self.block = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, d, e=None):
        d = F.interpolate(d, scale_factor=2, mode='bilinear', align_corners=True)
        # concat

        if e is not None:
            cat = torch.cat([e, d], dim=1)
            out = self.block(cat)
        else:
            out = self.block(d)
        return out


def final_block(in_channels, out_channels):
    block = nn.Sequential(
        nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    )
    return block


class Resnet34_Unet(nn.Module):

    def __init__(self, in_channel, out_channel, pretrained=False):
        super(Resnet34_Unet, self).__init__()

        self.resnet = models.resnet34(pretrained=pretrained)
        self.layer0 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
            self.resnet.maxpool
        )

        # Encode
        self.layer1 = self.resnet.layer1
        self.layer2 = self.resnet.layer2
        self.layer3 = self.resnet.layer3
        self.layer4 = self.resnet.layer4

        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        )

        # Decode
        self.conv_decode4 = expansive_block(1024+512, 512, 512)
        self.conv_decode3 = expansive_block(512+256, 256, 256)
        self.conv_decode2 = expansive_block(256+128, 128, 128)
        self.conv_decode1 = expansive_block(128+64, 64, 64)
        self.conv_decode0 = expansive_block(64, 32, 32)
        self.final_layer = final_block(32, out_channel)

    def forward(self, x):
        x = self.layer0(x)
        # Encode
        encode_block1 = self.layer1(x)
        encode_block2 = self.layer2(encode_block1)
        encode_block3 = self.layer3(encode_block2)
        encode_block4 = self.layer4(encode_block3)

        # Bottleneck
        bottleneck = self.bottleneck(encode_block4)

        # Decode
        decode_block4 = self.conv_decode4(bottleneck, encode_block4)
        decode_block3 = self.conv_decode3(decode_block4, encode_block3)
        decode_block2 = self.conv_decode2(decode_block3, encode_block2)
        decode_block1 = self.conv_decode1(decode_block2, encode_block1)
        decode_block0 = self.conv_decode0(decode_block1)

        final_layer = self.final_layer(decode_block0)

        return final_layer


flag = 0

if flag:
    image = torch.rand(1, 3, 572, 572)
    Resnet34_Unet = Resnet34_Unet(in_channel=3, out_channel=1)
    mask = Resnet34_Unet(image)
    print(mask.shape)

# 测试网络
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Resnet34_Unet(in_channel=1, out_channel=1, pretrained=True).to(device)
summary(model, input_size=(3, 512, 512))

你可能感兴趣的:(医学影像)