Reproducing UNet++, Including Deep Supervision

Note

When writing the code, follow the order from the bottom-left to the top-right of the architecture diagram; this keeps the reasoning clear, as formalized by the recursion below. The figure in the original UNet++ paper does not give much detail, so it is best to reproduce plain UNet first; after that, UNet++ is easy to pick up.
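
For orientation, every node x_i_j in the code follows the recursion from the UNet++ paper (Zhou et al., 2018), written here in a form lightly adapted to this implementation, with H(·) the double-conv block, D(·) downsampling, U(·) upsampling, and [·] channel-wise concatenation:

$$
x^{i,j} =
\begin{cases}
H\big(D(x^{i-1,0})\big), & j = 0,\ i > 0 \\
H\big(\big[\,[x^{i,k}]_{k=0}^{j-1},\; U(x^{i+1,j-1})\,\big]\big), & j > 0
\end{cases}
$$

Each column j of the grid is computed from the columns to its left plus one upsampled feature from the row below, which is why coding from the bottom-left to the top-right works.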

[Figure: UNet++ architecture diagram]

Code:

# coding:utf8

import torch
from torch import nn


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, pre_BatchNorm=False):
        super(Conv, self).__init__()
        if pre_BatchNorm:
            self.conv = nn.Sequential(
                nn.BatchNorm2d(in_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(out_channels),
                nn.SiLU(),
            )

    def forward(self, x):
        return self.conv(x)


# Downsampling
class Down_Conv(nn.Module):
    def __init__(self, channels):
        super(Down_Conv, self).__init__()
        # self.down_conv = nn.Sequential(
        #     # The original UNet uses only a MaxPool here; a conv layer can be
        #     # added after the maxpool to fuse features
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        #     nn.BatchNorm2d(channels),
        #     nn.SiLU(inplace=True)
        # )

        # To fuse more information, a strided convolution seems the better choice:
        # stride 2 halves H and W while keeping the channel count unchanged
        self.down_conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.SiLU(),
        )

    def forward(self, x):
        return self.down_conv(x)


# Upsampling
class Up_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up_Conv, self).__init__()
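        # kernel 3, stride 2, padding 1, output_padding 1 exactly doubles H and W:
        # out = (in - 1) * 2 - 2 * 1 + 3 + 1 = 2 * in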
        self.up = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )

    def forward(self, x):
        return self.up(x)


class UnetPlusPlus(nn.Module):
    def __init__(self, supervised):
        super(UnetPlusPlus, self).__init__()

        self.supervised = supervised

        self.stage1 = Conv(3, 64, pre_BatchNorm=True)
        self.stage1_down = Down_Conv(64)
        self.stage2 = Conv(64, 128, True)
        self.stage2_up = Up_Conv(128, 64)
        self.stage2_down = Down_Conv(128)
        self.stage3 = Conv(128, 256, True)
        self.stage3_up = Up_Conv(256, 128)
        self.stage3_down = Down_Conv(256)
        self.stage4 = Conv(256, 512, True)
        self.stage4_up = Up_Conv(512, 256)
        self.stage4_down = Down_Conv(512)
        self.stage5 = Conv(512, 1024, True)
        self.stage5_up = Up_Conv(1024, 512)

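        # Nested decoder nodes: x_i_j receives j earlier same-depth features plus
        # one upsampled feature, hence the (j + 1) * channels input width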
        self.x_0_1 = Conv(64 * 2, 64)
        self.x_0_2 = Conv(64 * 3, 64)
        self.x_0_3 = Conv(64 * 4, 64)
        self.x_0_4 = Conv(64 * 5, 64)

        self.x_1_1 = Conv(128 * 2, 128)
        self.x_1_1_up = Up_Conv(128, 64)
        self.x_1_2 = Conv(128 * 3, 128)
        self.x_1_2_up = Up_Conv(128, 64)
        self.x_1_3 = Conv(128 * 4, 128)
        self.x_1_3_up = Up_Conv(128, 64)

        self.x_2_1 = Conv(256 * 2, 256)
        self.x_2_1_up = Up_Conv(256, 128)
        self.x_2_2 = Conv(256 * 3, 256)
        self.x_2_2_up = Up_Conv(256, 128)

        self.x_3_1 = Conv(512 * 2, 512)
        self.x_3_1_up = Up_Conv(512, 256)

        self.end = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Softmax(dim=1)
        )
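        # Note: one shared output head is reused for every supervised output in
        # forward(); the reference UNet++ implementation uses a separate 1x1 conv
        # head per deep-supervision branch, so sharing here is a simplification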

    def forward(self, x):
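        # Encoder backbone (the j = 0 column of the grid): x_i_0 is the feature
        # at depth i, each stage halving the spatial resolution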
        x_0_0 = self.stage1(x)
        x_1_0 = self.stage2(self.stage1_down(x_0_0))
        x_2_0 = self.stage3(self.stage2_down(x_1_0))
        x_3_0 = self.stage4(self.stage3_down(x_2_0))
        x_4_0 = self.stage5(self.stage4_down(x_3_0))


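        # First skip-pathway column (j = 1): each node fuses the same-depth
        # encoder feature with the upsampled feature from one level deeper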
        x_0_1 = self.x_0_1(torch.cat([x_0_0, self.stage2_up(x_1_0)], dim=1))
        x_1_1 = self.x_1_1(torch.cat([x_1_0, self.stage3_up(x_2_0)], dim=1))
        x_2_1 = self.x_2_1(torch.cat([x_2_0, self.stage4_up(x_3_0)], dim=1))
        x_3_1 = self.x_3_1(torch.cat([x_3_0, self.stage5_up(x_4_0)], dim=1))

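        # Deeper columns (j = 2, 3, 4): each node x_i_j concatenates all earlier
        # same-depth nodes with the upsampled node from one level deeper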
        x_0_2 = self.x_0_2(torch.cat([x_0_0, x_0_1, self.x_1_1_up(x_1_1)], dim=1))
        x_1_2 = self.x_1_2(torch.cat([x_1_0, x_1_1, self.x_2_1_up(x_2_1)], dim=1))
        x_2_2 = self.x_2_2(torch.cat([x_2_0, x_2_1, self.x_3_1_up(x_3_1)], dim=1))

        x_0_3 = self.x_0_3(torch.cat([x_0_0, x_0_1, x_0_2, self.x_1_2_up(x_1_2)], dim=1))
        x_1_3 = self.x_1_3(torch.cat([x_1_0, x_1_1, x_1_2, self.x_2_2_up(x_2_2)], dim=1))

        x_0_4 = self.x_0_4(torch.cat([x_0_0, x_0_1, x_0_2, x_0_3, self.x_1_3_up(x_1_3)], dim=1))

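        # Deep supervision: return every full-resolution nested output so that a
        # loss can be attached to each of x_0_1 ... x_0_4 (see sketch after the code)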
        if self.supervised:
            return self.end(x_0_1), self.end(x_0_2), self.end(x_0_3), self.end(x_0_4)
        else:
            return self.end(x_0_4)


if __name__ == '__main__':
    xx = torch.randn((1, 3, 640, 640))
    mask = torch.rand(1, 3, 640, 640)

    model = UnetPlusPlus(supervised=True)

    # for name, layer in model.named_children():
    #     xx = layer(xx)
    #     print(name, xx.shape)

    x_0_1, x_0_2, x_0_3, x_0_4 = model(xx)
    # Placeholder arithmetic: this only checks that every supervised output has
    # the same shape as the mask; substitute a real loss function for training
    l1 = mask - x_0_1
    l2 = mask - x_0_2
    l3 = mask - x_0_3
    l4 = mask - x_0_4
    l = l1 + l2 + l3 + l4
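
For actual training, deep supervision is normally realized by attaching the same loss to each of the four outputs and averaging. A minimal sketch, assuming a hypothetical integer class-index target and continuing from the model class above (the network already ends in Softmax, so the loss takes the log of the probabilities instead of re-applying a softmax):

import torch
from torch import nn

model = UnetPlusPlus(supervised=True)
criterion = nn.NLLLoss()  # model outputs are already Softmax probabilities
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

images = torch.randn(2, 3, 256, 256)         # dummy image batch
target = torch.randint(0, 3, (2, 256, 256))  # dummy class-index mask

outputs = model(images)  # (x_0_1, x_0_2, x_0_3, x_0_4)
# Deep supervision: average the per-output loss over all four nested outputs
loss = sum(criterion(torch.log(out + 1e-8), target) for out in outputs) / len(outputs)

optimizer.zero_grad()
loss.backward()
optimizer.step()

At inference time only the deepest output x_0_4 is needed; as noted in the paper, deep supervision also allows pruning the network to an earlier output (e.g. x_0_1) to trade accuracy for speed.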

