5.1 PyTorch模型定义的方式


  • Sequential
  • Ordered Dict
  • ModuleList
  • ModuleDict

5.1.1 Sequential

## Sequential: Direct list
import torch.nn as nn

# Layers are passed positionally; nn.Sequential runs them in the given order
# and exposes them by integer index (net1[0], net1[1], ...).
net1 = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
        )
print(net1)
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
## Sequential: Ordered Dict
import collections
import torch.nn as nn

# Each (name, module) pair registers the module under that name, so the layers
# show up as fc1 / relu1 / fc2 instead of 0 / 1 / 2.
net2 = nn.Sequential(collections.OrderedDict([
          ('fc1', nn.Linear(784, 256)),
          ('relu1', nn.ReLU()),
          ('fc2', nn.Linear(256, 10))
          ]))
print(net2)
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)

# `torch` itself has not been imported yet at this point in the file
# (only `torch.nn as nn`), and torch.rand needs the top-level package.
import torch

# Push one random batch of 4 samples x 784 features through both networks.
a = torch.rand(4,784)
out1 = net1(a)
out2 = net2(a)
print(out1.shape==out2.shape, out1.shape)  # True means both networks produce the same output shape
True torch.Size([4, 10])

5.1.2 ModuleList

# Start from an empty container and fill it; ModuleList registers every
# module it receives, exactly as if they had been passed to the constructor.
net3 = nn.ModuleList()
net3.extend([nn.Linear(784, 256), nn.ReLU()])
net3.append(nn.Linear(256, 10))  # works like list.append
print(net3[-1])  # indexing works like a list
Linear(in_features=256, out_features=10, bias=True)
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
# Note: ModuleList does not define a network; it only stores the modules together, so a direct call like the one below does not work.
out3 = net3(a)

class Net3(nn.Module):
    """Network whose layers are stored in an ``nn.ModuleList``.

    The list holds Linear(784, 256), ReLU and Linear(256, 10); ``forward``
    supplies the execution order that the bare ModuleList does not define.
    """

    def __init__(self):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the attribute assignment below raises.
        super().__init__()
        self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
        self.modulelist.append(nn.Linear(256, 10))

    def forward(self, x):
        """Apply the stored layers to ``x`` one after another and return the result."""
        for layer in self.modulelist:
            x = layer(x)
        return x
# Wrapping the ModuleList in a Module with a forward() makes it callable.
net3_ = Net3()
out3_ = net3_(a)
print(out3_.shape)
torch.Size([4, 10])

5.1.3 ModuleDict

# ModuleDict stores named modules and supports dict-style insertion and lookup.
net4 = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net4['output'] = nn.Linear(256, 10)  # add an entry
print(net4['linear'])  # look up an entry by key
print(net4['output'])
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
# Likewise, ModuleDict does not define a network; it only stores the modules together. The call below is expected to raise an error.
# The correct usage is the same as above.
out4 = net4(a)

5.2 利用模型块快速搭建复杂网络

1)每个子块内部的两次卷积(Double Convolution)
2)左侧模型块之间的下采样连接,即最大池化(Max pooling)
3)右侧模型块之间的上采样连接(Up sampling)

(参考:https://github.com/milesial/Pytorch-UNet )

import os
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # 3x3 kernels with padding=1 keep the spatial size; the convolutions
        # carry no bias because each one is followed by a BatchNorm layer.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the DoubleConv stage.
    """

    def __init__(self, in_channels, out_channels):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()
        # The 2x2 max-pool halves height and width before the convolutions.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        """Pool ``x`` and then apply the double convolution."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor
            fed to the double convolution.
        out_channels: channels produced by the double convolution.
        bilinear: if True upsample with fixed bilinear interpolation,
            otherwise with a learned transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        # nn.Module must be initialised before sub-modules are assigned.
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # The transposed convolution halves the channel count itself.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2``, concatenate and convolve.

        ``x1`` is the tensor coming from the deeper level, ``x2`` the skip
        connection whose height and width ``x1`` is padded to.
        """
        x1 = self.up(x1)
        # input is NCHW: dim 2 is height, dim 3 is width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # F.pad takes (left, right, top, bottom); any odd remainder goes to
        # the right / bottom side.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Output head: a single 1x1 convolution from ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=1 mixes channels at each position without touching
        # the spatial dimensions.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Return the 1x1 convolution of ``x``."""
        projected = self.conv(x)
        return projected
## 组装
class UNet(nn.Module):
    """U-Net assembled from the DoubleConv / Down / Up / OutConv blocks.

    Args:
        n_channels: channels of the input passed to the first DoubleConv.
        n_classes: channels of the logits produced by the output head.
        bilinear: forwarded to every ``Up`` block; also halves the channel
            widths at the bottleneck and in the first three decoder stages.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Channel divisor used at the bottleneck and in the decoder.
        factor = 2 if bilinear else 1

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder path and output head.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Encode ``x``, decode with skip connections, and return the logits."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)

        # Each decoder stage receives the encoder output of the same level.
        dec = self.up1(bottleneck, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)
        return self.outc(dec)
unet = UNet(3,1)
5.3 pytorch模型修改


  • 修改特定层
  • 添加额外输入
  • 添加额外输出

5.3.1 修改特定层

import copy
unet1 = copy.deepcopy(unet)
  (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
b = torch.rand(1,3,224,224)
out_unet1 = unet1(b)
torch.Size([1, 1, 224, 224])
unet1.outc = OutConv(64, 5)
  (conv): Conv2d(64, 5, kernel_size=(1, 1), stride=(1, 1))
out_unet1 = unet1(b)
torch.Size([1, 5, 224, 224])

5.3.2 添加额外输入

class UNet2(nn.Module):
    """U-Net variant whose ``forward`` takes one extra input tensor.

    The layers are the same as in ``UNet``; only ``forward`` differs.

    Args:
        n_channels: channels of the input passed to the first DoubleConv.
        n_classes: channels of the logits produced by the output head.
        bilinear: forwarded to every ``Up`` block; also halves the channel
            widths at the bottleneck and in the first three decoder stages.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Channel divisor used at the bottleneck and in the decoder.
        factor = 2 if bilinear else 1

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder path and output head.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x, add_variable):
        """Return the logits for ``x`` after adding ``add_variable`` to the last decoder output.

        ``add_variable`` is added with ``+`` right before the output head, so
        it must be broadcastable against the output of ``up4``.
        """
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)

        dec = self.up1(bottleneck, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)

        # Modification point: inject the extra input before the output head.
        dec = dec + add_variable
        return self.outc(dec)
unet2 = UNet2(3,1)

c = torch.rand(1,1,224,224)
out_unet2 = unet2(b, c)
torch.Size([1, 1, 224, 224])

5.3.3 添加额外输出

class UNet3(nn.Module):
    """U-Net variant whose ``forward`` also returns the bottleneck feature map.

    The layers are the same as in ``UNet``; only the return value differs.

    Args:
        n_channels: channels of the input passed to the first DoubleConv.
        n_classes: channels of the logits produced by the output head.
        bilinear: forwarded to every ``Up`` block; also halves the channel
            widths at the bottleneck and in the first three decoder stages.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Channel divisor used at the bottleneck and in the decoder.
        factor = 2 if bilinear else 1

        # Encoder path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder path and output head.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return ``(logits, bottleneck)`` where ``bottleneck`` is the output of ``down4``."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)

        dec = self.up1(bottleneck, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)
        logits = self.outc(dec)

        # Modification point: expose the deepest encoder feature as a second output.
        return logits, bottleneck
unet3 = UNet3(3,1)

c = torch.rand(1,1,224,224)
out_unet3, mid_out = unet3(b)
print(out_unet3.shape, mid_out.shape)
torch.Size([1, 1, 224, 224]) torch.Size([1, 512, 14, 14])

5.4 模型保存与读取


## 讲解点:回到jupyter的文件目录下,看保存的结果

5.4.1 模型的保存格式


5.4.2 单卡与多卡模型存储的区别

# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. Newer PyTorch releases may need weights_only=False to load a whole
# pickled model (as opposed to a state_dict) -- confirm for the installed version.

# Single GPU, whole model: the pickled file contains the module object itself.
torch.save(unet, "./unet_example.pth")
loaded_unet = torch.load("./unet_example.pth")

# Single GPU, weights only: the file contains just the state_dict.
torch.save(unet.state_dict(), "./unet_weight_example.pth")
loaded_unet_weights = torch.load("./unet_weight_example.pth")

# Multi GPU: wrap a copy of the model in DataParallel.
# NOTE(review): CUDA_VISIBLE_DEVICES presumably has to be set before CUDA is
# first initialised in this process -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
unet_mul = copy.deepcopy(unet)
unet_mul = nn.DataParallel(unet_mul).cuda()

# Multi GPU, whole model: the saved object is the DataParallel wrapper.
torch.save(unet_mul, "./unet_mul_example.pth")
loaded_unet_mul = torch.load("./unet_mul_example.pth")

# Multi GPU, weights only: keys are those of the DataParallel wrapper.
torch.save(unet_mul.state_dict(), "./unet_weight_mul_example.pth")
loaded_unet_weights_mul = torch.load("./unet_weight_mul_example.pth")
# Rebuild from the plain model and wrap it exactly once, so the parameter
# names match the saved ones; wrapping the existing wrapper again would nest
# DataParallel inside DataParallel.
unet_mul = copy.deepcopy(unet)
unet_mul = nn.DataParallel(unet_mul).cuda()
unet_mul.load_state_dict(loaded_unet_weights_mul)

# Even when the whole model was saved, build a fresh model and copy the
# weights into it with load_state_dict; assigning to `.state_dict` would only
# replace the method and leave the parameters untouched.
unet_mul = copy.deepcopy(unet)
unet_mul = nn.DataParallel(unet_mul).cuda()
unet_mul.load_state_dict(loaded_unet_mul.state_dict())
