【PyTorch】PyTorch模型定义

本节内容学习了:

  1. 模型自定义的三种方式:Sequential, Modile List,Module dict;
  2. 当模型有重复出现的层结构,我们可以构建模型块实现复用来构建复杂模型;
  3. 如果要对模型进行修改:通过实例化来修改特定的层、通过在forward中增加参数来增加输入(要注意修改函数体内增加的输入是如何起作用的、对应的模型的定义也要进行修改)、通过return特定的值来实现额外输出;
  4. 模型的读取和保存要考虑2种训练模式2种保存类型。

1 模型自定义的三种方式:

  • Sequential:层的排列,适用于层数较少的情况,但不需要定义forward。

Direct List:直接列出

## Sequential: Direct list
import torch.nn as nn
net1 = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
print(net1)
Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

Ordered Dict:有序字典(Python中的字典无序,但网络是有序的,因此用Ordered Dict)

## Sequential: Ordered Dict
import collections
import torch.nn as nn
net2 = nn.Sequential(collections.OrderedDict([
          ('fc1', nn.Linear(784, 256)), # fully connect
          ('relu1', nn.ReLU()),
          ('fc2', nn.Linear(256, 10))
          ]))
print(net2)
Sequential(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

用OrderedDict可以实现层的命名。

  • Module List :对于模块比较重复的网络比较实用,但是需要定义forward并实例化。

一个错误展示:

# ModuleList
net3 = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])  #[nn.Linear(x,x) for i in range(5)]
net3.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net3[-1])  # 类似List的索引访问,输出最后一层
print(net3)

注意:ModuleList定义后并没有形成网络,只是将不同的模块存储到一起,因此不能直接输入数据,因此还需要定义forward()。正确方式如下:

class Net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.modulelist = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
        self.modulelist.append(nn.Linear(256, 10))
    
    def forward(self, x):
        for layer in self.modulelist:
            x = layer(x)
        return x
net3_ = Net3()
out3_ = net3_(a)
print(out3_.shape)
  • Module Dict:对于模块比较重复的网络比较实用,但是需要定义forward并实例化。
## ModuleDict
net4 = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net4['output'] = nn.Linear(256, 10) # 添加
print(net4['linear']) # 访问
print(net4.output)

同样地,ModuleDict并没有定义一个网络,它只是将不同的模块储存在一起。此处应报错。正确使用方式同上。

2 利用模型块快速搭建复杂网络

重复出现的层成为模块。以U-Net为例,介绍如何构建模型块,以及如何利用模型块快速搭建复杂模型。
【PyTorch】PyTorch模型定义_第1张图片

组成U-Net的模型块主要有如下几个部分:
1)每个子块内部的两次卷积(Double Convolution)
2)左侧模型块之间的下采样连接,即最大池化(Max pooling)
3)右侧模型块之间的上采样连接(Up sampling)
4)输出层的处理

除模型块外,还有模型块之间的横向连接,输入和U-Net底部的连接等计算,这些单独的操作可以通过forward函数来实现。 (参考:https://github.com/milesial/Pytorch-UNet )

import os
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
## 组装
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
unet = UNet(3,1) #实例化
unet

3 模型修改

  • 修改特定的层
## 修改特定层
import copy
unet1 = copy.deepcopy(unet)#复制
unet1.outc

尝试进行修改:

b = torch.rand(1,3,224,224)#1个batch,3个通道,大小224*224
out_unet1 = unet1(b)
print(out_unet1.shape)
torch.Size([1, 1, 224, 224])

这里输出了1个通道,如果我想让它输出5个通道该如何做呢?

unet1.outc = OutConv(64, 5) #重新实例化
unet1.outc
OutConv(
  (conv): Conv2d(64, 5, kernel_size=(1, 1), stride=(1, 1))
)

这里,输出就变成了5个通道。

  • 添加额外输入
    只需要修改两个地方:forward中增加一个输入,函数体内根据需要修改add_variable如何起作用的;修改对应的模型的定义。
## 添加额外输入
class UNet2(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet2, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x, add_variable):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = x + add_variable   #修改点
        logits = self.outc(x)
        return logits
unet2 = UNet2(3,1)

c = torch.rand(1,1,224,224)
out_unet2 = unet2(b, c)
print(out_unet2.shape)
  • 添加额外的输出
    只需要修改return。假设想利用Unet的bottom layer做分类,只需要return中增加x5。
## 添加额外输出
class UNet3(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet3, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits, x5  # 修改点
unet3 = UNet3(3,1)

c = torch.rand(1,1,224,224)
out_unet3, mid_out = unet3(b)
print(out_unet3.shape, mid_out.shape)

模型保存和读取

  • 模型保存
    需要考虑两种两种训练模式下的两种保存类型:
    • 两种训练模式:单卡、多卡
    • 两种保存类型:保存整个模型、保存模型权重

unet.state_dict()用字典的方式将每个层的权重保存下来,

## CPU或单卡:保存&读取整个模型
torch.save(unet, "./unet_example.pth")
loaded_unet = torch.load("./unet_example.pth")
loaded_unet.state_dict()
## CPU或单卡:保存&读取模型权重
torch.save(unet.state_dict(), "./unet_weight_example.pth")
loaded_unet_weights = torch.load("./unet_weight_example.pth")
unet.load_state_dict(loaded_unet_weights)
unet.state_dict()
## 多卡:保存&读取整个模型。注意模型层名称前多了module
## 不建议,因为保存模型的GPU_id等信息和读取后训练环境可能不同,尤其是要把保存的模型交给另一用户使用的情况
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
unet_mul = copy.deepcopy(unet)
unet_mul = nn.DataParallel(unet_mul).cuda()
unet_mul
torch.save(unet_mul, "./unet_mul_example.pth")
loaded_unet_mul = torch.load("./unet_mul_example.pth")
loaded_unet_mul
## 多卡:保存&读取模型权重。
torch.save(unet_mul.state_dict(), "./unet_weight_mul_example.pth")
loaded_unet_weights_mul = torch.load("./unet_weight_mul_example.pth")
unet_mul.load_state_dict(loaded_unet_weights_mul)
unet_mul = nn.DataParallel(unet_mul).cuda()
unet_mul.state_dict()
# 另外,如果保存的是整个模型,也建议采用提取权重的方式构建新的模型:
unet_mul.state_dict = loaded_unet_mul.state_dict
unet_mul = nn.DataParallel(unet_mul).cuda()
unet_mul.state_dict()

你可能感兴趣的:(PyTorch,Python,pytorch,深度学习,python)