resnet_unetpp

There are two main ways to improve the original U-Net: improving the convolution blocks, and improving the overall model structure.
For the convolution blocks, the plain double-conv blocks can be replaced with residual blocks.
For the model structure, additional nested "U"s can be added on top of the original one, which gives UNet++.
ResNet comes in five standard variants: resnet18, resnet34, resnet50, resnet101, and resnet152, where the number after "resnet" is the total count of weighted layers. The first two variants are built from basic blocks, the last three from bottleneck blocks.
[Figure 1. Image from [1]: a basic block (left) and a bottleneck block (right).]
As shown in Figure 1, the left side is a basic block and the right side is a bottleneck block. A basic block consists of two 3x3 convolutions with the same channel count, plus a shortcut that bypasses them. A bottleneck block consists of three convolutions with kernel sizes 1x1, 3x3, and 1x1, where the first two share the same channel count and the third has 4 times that count, plus a shortcut that bypasses all three.
[Figure 2. Image from [1]: kernel sizes and channel counts for the five ResNet variants.]
Figure 2 shows the kernel-size and channel settings inside each of the five ResNet variants.
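In terms of blocks per stage (matching the factory functions in the code below), resnet18 uses [2, 2, 2, 2] basic blocks, resnet34 uses [3, 4, 6, 3] basic blocks, and resnet50, resnet101, and resnet152 use [3, 4, 6, 3], [3, 4, 23, 3], and [3, 8, 36, 3] bottleneck blocks respectively.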

The structural improvement to U-Net is nesting, as shown in Figure 3:
[Figure 3. Image from [2]: the nested UNet++ architecture.]
The original U-Net downsamples n times and then upsamples n times. UNet++ builds on this by also upsampling n-1 times starting right after the (n-1)-th downsampling, and so on, until it upsamples once right after the first downsampling. Each newly added upsampling node also gets its corresponding skip connections.
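In the notation of [2], the feature map at node x^{i,j} (depth i, column j) is, for j > 0,

x^{i,j} = H([x^{i,0}, ..., x^{i,j-1}, U(x^{i+1,j-1})]),

where H is a convolution block, U is upsampling, and [ ] denotes channel concatenation; for j = 0 the node is simply computed from the node above it in the encoder column.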
UNet++ mainly comes in four variants of increasing depth (L1 to L4), as shown in Figure 4:
[Figure 4. Image from [2]: the four UNet++ variants.]

The code for ResNet combined with UNet++ L2 is as follows:


import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
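    # Basic residual block used by resnet18/34: two 3x3 convs with BN,
    # plus a shortcut; output channels = out_channels * expansion.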
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        # Identity shortcut, replaced by a 1x1 projection when the spatial
        # size or the channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return F.relu(self.residual_function(x) + self.shortcut(x), inplace=True)

class BottleNeck(nn.Module):
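    # Bottleneck residual block used by resnet50/101/152: 1x1 reduce, 3x3,
    # then 1x1 expand to 4x out_channels.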
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return F.relu(self.residual_function(x) + self.shortcut(x), inplace=True)

class ResNet(nn.Module):
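    # Backbone encoder/classifier. num_block lists how many residual blocks
    # go into each of the four stages, e.g. [2, 2, 2, 2] for resnet18.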

    def __init__(self, in_chans, block, num_block, num_classes=100):
        super().__init__()

        self.block = block
        self.in_channels = 64

        # Stem: a 3x3 stride-1 convolution (rather than the 7x7 stride-2
        # convolution of [1]); downsampling is left to the max pool below.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        # Only the first block of a stage uses the given stride (possibly
        # downsampling); the remaining blocks use stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        f1 = self.conv1(x)
        f2 = self.conv2_x(self.pool(f1))
        f3 = self.conv3_x(f2)
        f4 = self.conv4_x(f3)
        f5 = self.conv5_x(f4)
        output = self.avg_pool(f5)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        # Expose the shallow feature maps along with the classification
        # output; ResNet_UNetpp below only reuses the block type, not this
        # forward pass.
        return f1, f2, f3, output

def resnet18(in_chans):
    return ResNet(in_chans, BasicBlock, [2, 2, 2, 2])

def resnet34(in_chans):
    return ResNet(in_chans, BasicBlock, [3, 4, 6, 3])

def resnet50(in_chans):
    return ResNet(in_chans, BottleNeck, [3, 4, 6, 3])

def resnet101(in_chans):
    return ResNet(in_chans, BottleNeck, [3, 4, 23, 3])

def resnet152(in_chans):
    return ResNet(in_chans, BottleNeck, [3, 8, 36, 3])

"""### ResNet_UNetpp"""

class ConvBlock(nn.Module):
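  # Plain double 3x3 conv block in the original U-Net style, kept for
  # reference; UpConvBlock below uses the residual BasicBlock instead.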

  def __init__(self, in_chans, out_chans, stride):
    super(ConvBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(out_chans)
    self.relu1 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(out_chans)
    self.relu2 = nn.ReLU(inplace=True)

  def forward(self, x):
    x = self.relu1(self.bn1(self.conv1(x)))
    out = self.relu2(self.bn2(self.conv2(x)))
    return out


class UpConvBlock(nn.Module):

  def __init__(self, in_chans, bridge_chans_list, out_chans):
    super(UpConvBlock, self).__init__()
    # Upsample 2x with a transposed conv, then fuse the upsampled feature
    # with all same-depth skip inputs ("bridges") via concatenation
    # followed by a residual BasicBlock.
    self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
    self.conv_block = BasicBlock(out_chans + sum(bridge_chans_list), out_chans, 1)

  def forward(self, x, bridge_list):
    x = self.up(x)
    x = torch.cat([x] + bridge_list, dim=1)
    out = self.conv_block(x)
    return out

class ResNet_UNetpp(nn.Module):

  def __init__(self, in_chans=1, n_classes=2, backbone=resnet18):
    super(ResNet_UNetpp, self).__init__()

    # Compatible with resnet18/34/50/101/152: only the residual block type
    # and its channel expansion factor are taken from the backbone.
    block = backbone(in_chans).block
    expansion = block.expansion
    feat_chans = [64 * expansion, 128 * expansion, 256 * expansion]

    # Encoder column x(0,0) -> x(1,0) -> x(2,0); each block multiplies its
    # out_channels argument by the expansion factor internally.
    self.conv_x00 = block(in_chans, 64, 1)
    self.conv_x10 = block(feat_chans[0], 128, 2)
    self.conv_x20 = block(feat_chans[1], 256, 2)

    # Nested decoder nodes x(0,1), x(1,1) and x(0,2) of UNet++ L2.
    self.conv_x01 = UpConvBlock(feat_chans[1], [feat_chans[0]], feat_chans[0])
    self.conv_x11 = UpConvBlock(feat_chans[2], [feat_chans[1]], feat_chans[1])
    self.conv_x02 = UpConvBlock(feat_chans[1], [feat_chans[0], feat_chans[0]], feat_chans[0])

    # One 1x1 classification head per supervised decoder node.
    self.cls_conv_x01 = nn.Conv2d(feat_chans[0], n_classes, kernel_size=1)
    self.cls_conv_x02 = nn.Conv2d(feat_chans[0], n_classes, kernel_size=1)

  def forward(self, x):
    # Encoder path.
    x00 = self.conv_x00(x)
    x10 = self.conv_x10(x00)
    x20 = self.conv_x20(x10)
    # Nested decoder: each node takes the upsampled deeper node plus all
    # same-depth nodes to its left as skip connections.
    x01 = self.conv_x01(x10, [x00])
    x11 = self.conv_x11(x20, [x10])
    x02 = self.conv_x02(x11, [x00, x01])
    # Deep supervision: one output per supervised decoder node.
    out01 = self.cls_conv_x01(x01)
    out02 = self.cls_conv_x02(x02)

    # Print intermediate shapes as a quick sanity check.
    print('x00', x00.shape)
    print('x10', x10.shape)
    print('x20', x20.shape)
    print('x01', x01.shape)
    print('x11', x11.shape)
    print('x02', x02.shape)
    print('out01', out01.shape)
    print('out02', out02.shape)

    return out01, out02

# Smoke test: a batch of two single-channel 224x224 images through the
# resnet50-backed UNet++ L2.
x = torch.randn((2, 1, 224, 224), dtype=torch.float32)
model = ResNet_UNetpp(in_chans=1, backbone=resnet50)
y1, y2 = model(x)

Result (the shapes printed by the forward pass):

x00 torch.Size([2, 256, 224, 224])
x10 torch.Size([2, 512, 112, 112])
x20 torch.Size([2, 1024, 56, 56])
x01 torch.Size([2, 256, 224, 224])
x11 torch.Size([2, 512, 112, 112])
x02 torch.Size([2, 256, 224, 224])
out01 torch.Size([2, 2, 224, 224])
out02 torch.Size([2, 2, 224, 224])

References:
[1] He K, Zhang X, Ren S, et al. Deep residual learning for image recognition. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016: 770-778.
[2] Zhou Z, Siddiquee M M R, Tajbakhsh N, et al. UNet++: A nested U-Net architecture for medical image segmentation. Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support. Springer, Cham, 2018: 3-11.
