U-Net模型搭建python实现

最近在学习 U-Net。读完文章并参考了前人的模型实现之后,我试着自己用 PyTorch 搭建了一遍模型框架,整理如下,适合初学者参考。
U-Net模型搭建python实现_第1张图片

import torch
import torch.nn as nn
from torchsummary import summary



# Two consecutive 3*3 convolutions.
# Remark: the first 3*3 convolution is the one that raises or lowers the
# channel count (in_channels -> mid_channels).
class DoubleConv(nn.Module):
    """Two stacked Conv2d(3x3) -> BatchNorm2d -> ReLU stages.

    Both convolutions use ``padding=1`` and ``stride=1`` ("same" convolution,
    unlike the unpadded "valid" convolution of the original U-Net paper), so
    the spatial size of the input is preserved. The convolutions carry no bias
    term; each one is followed directly by a BatchNorm2d layer.

    Args:
        in_channels: channel count of the input tensor.
        mid_channels: channel count produced by the first convolution.
        out_channels: channel count produced by the second convolution.
    """

    def __init__(self, in_channels,mid_channels, out_channels):
        super().__init__()

        # Build the stages in order (first conv, then second conv) so the
        # resulting Sequential has the layout conv/bn/relu/conv/bn/relu.
        stages = []
        for src, dst in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages.append(
                nn.Conv2d(src, dst, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(dst))
            stages.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        features = self.conv(x)
        return features


# Main network definition.
class UNet(nn.Module):
    """U-Net built from ``DoubleConv`` stages with transposed-conv upsampling.

    The contracting path runs five ``DoubleConv`` stages (64, 128, 256, 512,
    1024 channels) separated by four 2x2 max-poolings. The expansive path
    upsamples four times with ``ConvTranspose2d``, concatenates each result
    with the matching encoder feature map along the channel axis, and refines
    it with a ``DoubleConv``. A final 1x1 convolution maps the 64 channels to
    ``num_classes`` channels; no activation is applied to that output.

    Input height/width do not have to be multiples of 16: when max-pooling
    floors an odd size, the upsampled map is zero-padded to the size of its
    skip connection before concatenation, so the output keeps the input's
    spatial size. The input is halved four times, so very small inputs can
    still fail inside the pooling layers.

    Args:
        original_channels: channel count of the input image (e.g. 3 for RGB,
            1 for grayscale).
        num_classes: channel count of the output map (e.g. 1 for binary
            foreground/background segmentation, more for multi-class).
    """

    def __init__(self, original_channels=1, num_classes=1):
        super(UNet, self).__init__()
        self.original_channels = original_channels  # input may be 3-channel RGB or 1-channel grayscale
        self.num_classes = num_classes  # output may be binary (fore/background) or multi-class

        # Contracting path: one DoubleConv (two convolutions) followed by one
        # max-pool downsampling per stage; five encoder stages in total, the
        # last one without pooling.
        self.encoder1 = DoubleConv(self.original_channels,mid_channels=64, out_channels=64)
        self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder2 = DoubleConv(in_channels=64, mid_channels=128,out_channels=128)
        self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder3 = DoubleConv(in_channels=128, mid_channels=256,out_channels=256)
        self.down3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder4 = DoubleConv(in_channels=256, mid_channels=512,out_channels=512)
        self.down4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder5 = DoubleConv(in_channels=512,mid_channels=1024, out_channels=1024)

        # Expansive path: one ConvTranspose2d upsampling followed by one
        # DoubleConv per stage (four stages), then a final 1*1 convolution.
        # Remark: the channel concatenation happens in forward(); each
        # decoder's in_channels is therefore skip channels + upsampled channels.

        self.up1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(in_channels=1024, mid_channels=512,out_channels=512)

        self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(in_channels=512,mid_channels=256, out_channels=256)

        self.up3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(in_channels=256,mid_channels=128, out_channels=128)

        self.up4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(in_channels=128, mid_channels=64,out_channels=64)

        self.decoder5 = nn.Conv2d(64, num_classes, kernel_size=1)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the 2x upsampling the map can be
        one pixel smaller than the skip feature map; ``torch.cat`` would then
        fail. When the sizes already agree the tensor is returned unchanged.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # Pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    # Forward pass.

    def forward(self, x):
        """Run the encoder/decoder and return the ``num_classes``-channel map.

        The shape notes on the right assume a 512*512 input.
        """
        encoder1 = self.encoder1(x)              # original_channels → 64_channels 512*512 → 512*512
        encoder1_pool = self.down1(encoder1)     # 64_channels → 64_channels       512*512 → 256*256
        encoder2 = self.encoder2(encoder1_pool)  # 64_channels → 128_channels      256*256 → 256*256
        encoder2_pool = self.down2(encoder2)     # 128_channels → 128_channels     256*256 → 128*128
        encoder3 = self.encoder3(encoder2_pool)  # 128_channels → 256_channels     128*128 → 128*128
        encoder3_pool = self.down3(encoder3)     # 256_channels → 256_channels     128*128 → 64*64
        encoder4 = self.encoder4(encoder3_pool)  # 256_channels → 512_channels       64*64 → 64*64
        encoder4_pool = self.down4(encoder4)     # 512_channels → 512_channels       64*64 → 32*32
        encoder5 = self.encoder5(encoder4_pool)  # 512_channels → 1024_channels      32*32 → 32*32

        decoder1_up = self.up1(encoder5)         # 1024_channels → 512_channels      32*32 → 64*64
        decoder1_up = self._pad_to_match(decoder1_up, encoder4)
        decoder1 = self.decoder1(torch.cat((encoder4, decoder1_up), dim=1))
                                                 # 512+512_channels → 512_channels   64*64 → 64*64

        decoder2_up = self.up2(decoder1)         # 512_channels → 256_channels       64*64 → 128*128
        decoder2_up = self._pad_to_match(decoder2_up, encoder3)
        decoder2 = self.decoder2(torch.cat((encoder3, decoder2_up), dim=1))
                                                 # 256+256_channels → 256_channels   128*128 → 128*128

        decoder3_up = self.up3(decoder2)         # 256_channels → 128_channels       128*128 → 256*256
        decoder3_up = self._pad_to_match(decoder3_up, encoder2)
        decoder3 = self.decoder3(torch.cat((encoder2, decoder3_up), dim=1))
                                                 # 128+128_channels → 128_channels   256*256 → 256*256

        decoder4_up = self.up4(decoder3)         # 128_channels → 64_channels        256*256 → 512*512
        decoder4_up = self._pad_to_match(decoder4_up, encoder1)
        decoder4 = self.decoder4(torch.cat((encoder1, decoder4_up), dim=1))
                                                 # 64+64_channels → 64_channels      512*512 → 512*512

        out = self.decoder5(decoder4)            # 64_channels → num_classes channels 512*512 → 512*512
        return out

"""
下面三行代码是验证模型架构用,实际不需要
"""
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# unet = UNet().to(device)
# summary(unet, (1, 512, 512))



这里使用的是第三方包 torchsummary 提供的 summary 功能(并非 PyTorch 自带);如果没有安装,可以执行 pip install torchsummary
三行测试代码开启后,可以查看模型的架构如下

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 512, 512]             576
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 512, 512]          36,864
       BatchNorm2d-5         [-1, 64, 512, 512]             128
              ReLU-6         [-1, 64, 512, 512]               0
        DoubleConv-7         [-1, 64, 512, 512]               0
         MaxPool2d-8         [-1, 64, 256, 256]               0
            Conv2d-9        [-1, 128, 256, 256]          73,728
      BatchNorm2d-10        [-1, 128, 256, 256]             256
             ReLU-11        [-1, 128, 256, 256]               0
           Conv2d-12        [-1, 128, 256, 256]         147,456
      BatchNorm2d-13        [-1, 128, 256, 256]             256
             ReLU-14        [-1, 128, 256, 256]               0
       DoubleConv-15        [-1, 128, 256, 256]               0
        MaxPool2d-16        [-1, 128, 128, 128]               0
           Conv2d-17        [-1, 256, 128, 128]         294,912
      BatchNorm2d-18        [-1, 256, 128, 128]             512
             ReLU-19        [-1, 256, 128, 128]               0
           Conv2d-20        [-1, 256, 128, 128]         589,824
      BatchNorm2d-21        [-1, 256, 128, 128]             512
             ReLU-22        [-1, 256, 128, 128]               0
       DoubleConv-23        [-1, 256, 128, 128]               0
        MaxPool2d-24          [-1, 256, 64, 64]               0
           Conv2d-25          [-1, 512, 64, 64]       1,179,648
      BatchNorm2d-26          [-1, 512, 64, 64]           1,024
             ReLU-27          [-1, 512, 64, 64]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,296
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
       DoubleConv-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 32, 32]       4,718,592
      BatchNorm2d-34         [-1, 1024, 32, 32]           2,048
             ReLU-35         [-1, 1024, 32, 32]               0
           Conv2d-36         [-1, 1024, 32, 32]       9,437,184
      BatchNorm2d-37         [-1, 1024, 32, 32]           2,048
             ReLU-38         [-1, 1024, 32, 32]               0
       DoubleConv-39         [-1, 1024, 32, 32]               0
  ConvTranspose2d-40          [-1, 512, 64, 64]       2,097,664
           Conv2d-41          [-1, 512, 64, 64]       4,718,592
      BatchNorm2d-42          [-1, 512, 64, 64]           1,024
             ReLU-43          [-1, 512, 64, 64]               0
           Conv2d-44          [-1, 512, 64, 64]       2,359,296
      BatchNorm2d-45          [-1, 512, 64, 64]           1,024
             ReLU-46          [-1, 512, 64, 64]               0
       DoubleConv-47          [-1, 512, 64, 64]               0
  ConvTranspose2d-48        [-1, 256, 128, 128]         524,544
           Conv2d-49        [-1, 256, 128, 128]       1,179,648
      BatchNorm2d-50        [-1, 256, 128, 128]             512
             ReLU-51        [-1, 256, 128, 128]               0
           Conv2d-52        [-1, 256, 128, 128]         589,824
      BatchNorm2d-53        [-1, 256, 128, 128]             512
             ReLU-54        [-1, 256, 128, 128]               0
       DoubleConv-55        [-1, 256, 128, 128]               0
  ConvTranspose2d-56        [-1, 128, 256, 256]         131,200
           Conv2d-57        [-1, 128, 256, 256]         294,912
      BatchNorm2d-58        [-1, 128, 256, 256]             256
             ReLU-59        [-1, 128, 256, 256]               0
           Conv2d-60        [-1, 128, 256, 256]         147,456
      BatchNorm2d-61        [-1, 128, 256, 256]             256
             ReLU-62        [-1, 128, 256, 256]               0
       DoubleConv-63        [-1, 128, 256, 256]               0
  ConvTranspose2d-64         [-1, 64, 512, 512]          32,832
           Conv2d-65         [-1, 64, 512, 512]          73,728
      BatchNorm2d-66         [-1, 64, 512, 512]             128
             ReLU-67         [-1, 64, 512, 512]               0
           Conv2d-68         [-1, 64, 512, 512]          36,864
      BatchNorm2d-69         [-1, 64, 512, 512]             128
             ReLU-70         [-1, 64, 512, 512]               0
       DoubleConv-71         [-1, 64, 512, 512]               0
           Conv2d-72          [-1, 1, 512, 512]              65
================================================================
Total params: 31,036,481
Trainable params: 31,036,481
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 3718.00
Params size (MB): 118.39
Estimated Total Size (MB): 3837.39
----------------------------------------------------------------

你可能感兴趣的:(卷积神经网络,python,cv)