I have recently been studying U-Net. After reading the paper and going through earlier implementations, I tried building the model framework myself; this walkthrough is aimed at beginners.
import torch
import torch.nn as nn
from torchsummary import summary
# Two successive 3x3 convolutions
# Remark: the first 3x3 convolution is the one that changes the channel count (up or down)
class DoubleConv(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=False),
            # Same conv., not valid conv. as in the original paper
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
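# Note: with padding=1 these are "same" convolutions, so DoubleConv preserves the
# spatial size; the original paper's unpadded ("valid") convolutions shrink each
# feature map by 2 pixels per conv and force cropping of the encoder features
# before the skip connections. A quick shape check (optional; the input size
# here is illustrative, not from the original post):
#   block = DoubleConv(in_channels=3, mid_channels=64, out_channels=64)
#   print(block(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 224, 224])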
# Define the main network
class UNet(nn.Module):
    def __init__(self, original_channels=1, num_classes=1):
        super(UNet, self).__init__()
        self.original_channels = original_channels  # the input may be 3-channel RGB or 1-channel grayscale, hence original_channels
        self.num_classes = num_classes  # the output may be binary (foreground/background) or multi-class, hence num_classes
        # Contracting path: one DoubleConv (two convs) then one MaxPool downsampling per stage, five encoder stages in total
        self.encoder1 = DoubleConv(self.original_channels, mid_channels=64, out_channels=64)
        self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = DoubleConv(in_channels=64, mid_channels=128, out_channels=128)
        self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = DoubleConv(in_channels=128, mid_channels=256, out_channels=256)
        self.down3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = DoubleConv(in_channels=256, mid_channels=512, out_channels=512)
        self.down4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder5 = DoubleConv(in_channels=512, mid_channels=1024, out_channels=1024)
        # Expansive path: one ConvTranspose upsampling then one DoubleConv per stage, five decoding steps in total, the last being a 1x1 conv.
        # Remark: channel concatenation happens in forward(); make sure the encoder and upsampled channel counts match
        self.up1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(in_channels=1024, mid_channels=512, out_channels=512)
        self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(in_channels=512, mid_channels=256, out_channels=256)
        self.up3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(in_channels=256, mid_channels=128, out_channels=128)
        self.up4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(in_channels=128, mid_channels=64, out_channels=64)
        self.decoder5 = nn.Conv2d(64, num_classes, kernel_size=1)
    # Define the forward pass
    def forward(self, x):
        encoder1 = self.encoder1(x)              # original_channels → 64_channels    512*512 → 512*512
        encoder1_pool = self.down1(encoder1)     # 64_channels → 64_channels          512*512 → 256*256
        encoder2 = self.encoder2(encoder1_pool)  # 64_channels → 128_channels         256*256 → 256*256
        encoder2_pool = self.down2(encoder2)     # 128_channels → 128_channels        256*256 → 128*128
        encoder3 = self.encoder3(encoder2_pool)  # 128_channels → 256_channels        128*128 → 128*128
        encoder3_pool = self.down3(encoder3)     # 256_channels → 256_channels        128*128 → 64*64
        encoder4 = self.encoder4(encoder3_pool)  # 256_channels → 512_channels        64*64 → 64*64
        encoder4_pool = self.down4(encoder4)     # 512_channels → 512_channels        64*64 → 32*32
        encoder5 = self.encoder5(encoder4_pool)  # 512_channels → 1024_channels       32*32 → 32*32
        decoder1_up = self.up1(encoder5)         # 1024_channels → 512_channels       32*32 → 64*64
        decoder1 = self.decoder1(torch.cat((encoder4, decoder1_up), dim=1))
        # 512+512_channels → 512_channels    64*64 → 64*64
        decoder2_up = self.up2(decoder1)         # 512_channels → 256_channels        64*64 → 128*128
        decoder2 = self.decoder2(torch.cat((encoder3, decoder2_up), dim=1))
        # 256+256_channels → 256_channels    128*128 → 128*128
        decoder3_up = self.up3(decoder2)         # 256_channels → 128_channels        128*128 → 256*256
        decoder3 = self.decoder3(torch.cat((encoder2, decoder3_up), dim=1))
        # 128+128_channels → 128_channels    256*256 → 256*256
        decoder4_up = self.up4(decoder3)         # 128_channels → 64_channels         256*256 → 512*512
        decoder4 = self.decoder4(torch.cat((encoder1, decoder4_up), dim=1))
        # 64+64_channels → 64_channels       512*512 → 512*512
        out = self.decoder5(decoder4)            # 64_channels → num_classes channels 512*512 → 512*512
        return out
"""
下面三行代码是验证模型架构用,实际不需要
"""
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# unet = UNet().to(device)
# summary(unet, (1, 512, 512))
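# A torchsummary-free alternative check (also optional, commented out): it assumes
# a 1-channel 512*512 input, matching the shape annotations in forward() above.
# model = UNet(original_channels=1, num_classes=1)
# with torch.no_grad():
#     print(model(torch.randn(1, 1, 512, 512)).shape)  # torch.Size([1, 1, 512, 512])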
This uses the summary function from the torchsummary package; if you don't have it, install it with pip install torchsummary.
After uncommenting the three test lines, the model architecture prints as follows:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 512, 512] 576
BatchNorm2d-2 [-1, 64, 512, 512] 128
ReLU-3 [-1, 64, 512, 512] 0
Conv2d-4 [-1, 64, 512, 512] 36,864
BatchNorm2d-5 [-1, 64, 512, 512] 128
ReLU-6 [-1, 64, 512, 512] 0
DoubleConv-7 [-1, 64, 512, 512] 0
MaxPool2d-8 [-1, 64, 256, 256] 0
Conv2d-9 [-1, 128, 256, 256] 73,728
BatchNorm2d-10 [-1, 128, 256, 256] 256
ReLU-11 [-1, 128, 256, 256] 0
Conv2d-12 [-1, 128, 256, 256] 147,456
BatchNorm2d-13 [-1, 128, 256, 256] 256
ReLU-14 [-1, 128, 256, 256] 0
DoubleConv-15 [-1, 128, 256, 256] 0
MaxPool2d-16 [-1, 128, 128, 128] 0
Conv2d-17 [-1, 256, 128, 128] 294,912
BatchNorm2d-18 [-1, 256, 128, 128] 512
ReLU-19 [-1, 256, 128, 128] 0
Conv2d-20 [-1, 256, 128, 128] 589,824
BatchNorm2d-21 [-1, 256, 128, 128] 512
ReLU-22 [-1, 256, 128, 128] 0
DoubleConv-23 [-1, 256, 128, 128] 0
MaxPool2d-24 [-1, 256, 64, 64] 0
Conv2d-25 [-1, 512, 64, 64] 1,179,648
BatchNorm2d-26 [-1, 512, 64, 64] 1,024
ReLU-27 [-1, 512, 64, 64] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,296
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
DoubleConv-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 32, 32] 4,718,592
BatchNorm2d-34 [-1, 1024, 32, 32] 2,048
ReLU-35 [-1, 1024, 32, 32] 0
Conv2d-36 [-1, 1024, 32, 32] 9,437,184
BatchNorm2d-37 [-1, 1024, 32, 32] 2,048
ReLU-38 [-1, 1024, 32, 32] 0
DoubleConv-39 [-1, 1024, 32, 32] 0
ConvTranspose2d-40 [-1, 512, 64, 64] 2,097,664
Conv2d-41 [-1, 512, 64, 64] 4,718,592
BatchNorm2d-42 [-1, 512, 64, 64] 1,024
ReLU-43 [-1, 512, 64, 64] 0
Conv2d-44 [-1, 512, 64, 64] 2,359,296
BatchNorm2d-45 [-1, 512, 64, 64] 1,024
ReLU-46 [-1, 512, 64, 64] 0
DoubleConv-47 [-1, 512, 64, 64] 0
ConvTranspose2d-48 [-1, 256, 128, 128] 524,544
Conv2d-49 [-1, 256, 128, 128] 1,179,648
BatchNorm2d-50 [-1, 256, 128, 128] 512
ReLU-51 [-1, 256, 128, 128] 0
Conv2d-52 [-1, 256, 128, 128] 589,824
BatchNorm2d-53 [-1, 256, 128, 128] 512
ReLU-54 [-1, 256, 128, 128] 0
DoubleConv-55 [-1, 256, 128, 128] 0
ConvTranspose2d-56 [-1, 128, 256, 256] 131,200
Conv2d-57 [-1, 128, 256, 256] 294,912
BatchNorm2d-58 [-1, 128, 256, 256] 256
ReLU-59 [-1, 128, 256, 256] 0
Conv2d-60 [-1, 128, 256, 256] 147,456
BatchNorm2d-61 [-1, 128, 256, 256] 256
ReLU-62 [-1, 128, 256, 256] 0
DoubleConv-63 [-1, 128, 256, 256] 0
ConvTranspose2d-64 [-1, 64, 512, 512] 32,832
Conv2d-65 [-1, 64, 512, 512] 73,728
BatchNorm2d-66 [-1, 64, 512, 512] 128
ReLU-67 [-1, 64, 512, 512] 0
Conv2d-68 [-1, 64, 512, 512] 36,864
BatchNorm2d-69 [-1, 64, 512, 512] 128
ReLU-70 [-1, 64, 512, 512] 0
DoubleConv-71 [-1, 64, 512, 512] 0
Conv2d-72 [-1, 1, 512, 512] 65
================================================================
Total params: 31,036,481
Trainable params: 31,036,481
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.00
Forward/backward pass size (MB): 3718.00
Params size (MB): 118.39
Estimated Total Size (MB): 3837.39
----------------------------------------------------------------
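Finally, a note on how the output would typically be used. With num_classes=1 the final 1x1 conv produces one raw logit per pixel, so binary (foreground/background) segmentation can be trained with a sigmoid-based loss. This is only a minimal, hypothetical sketch; the batch, masks, and learning rate below are placeholders I made up, not part of the model above:

model = UNet(original_channels=1, num_classes=1)
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, so the model can return raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # hypothetical learning rate

images = torch.randn(2, 1, 512, 512)                   # placeholder batch of grayscale images
masks = torch.randint(0, 2, (2, 1, 512, 512)).float()  # placeholder binary masks

logits = model(images)           # (2, 1, 512, 512), raw logits
loss = criterion(logits, masks)  # binary segmentation loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

For multi-class segmentation (num_classes > 1), nn.CrossEntropyLoss with an integer label map of shape (N, H, W) would be the usual drop-in replacement.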