Pytorch-day09-模型微调-checkpoint

模型微调(fine-tune)-迁移学习

  • torchvision微调
  • timm微调
  • 半精度训练

起源:

  • 1、随着深度学习的发展,模型的参数越来越大,许多开源模型都是在较大数据集上进行训练的,比如Imagenet-1k,Imagenet-11k等
  • 2、如果数据集可能只有几千张,训练几千万参数的大模型,过拟合无法避免
  • 3、如果我们想从零开始训练一个大模型,那么我们的解决办法是收集更多的数据。然而,收集和标注数据会花费大量的时间和资⾦,成本无法承受

解决方案:

  • 应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上
  • 比如:ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成
  • 模型微调(finetune):就是先找到一个同类的别人训练好的模型,基于已经训练好的模型换成自己的数据,通过训练调整一下参数

不同数据集下使用微调:

  • 数据集1 - 数据量少,但数据相似度非常高 - 在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别。

  • 数据集2 - 数据量少,数据相似度低 - 在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。

  • 数据集3 - 数据量大,数据相似度低 - 在这种情况下,由于我们有一个大的数据集,我们的神经网络训练将会很有效。但是,由于我们的数据与用于训练我们的预训练模型的数据相比有很大不同。使用预训练模型进行的预测不会有效。因此,最好根据你的数据从头开始训练神经网络(Training from scatch)

  • 数据集4 - 数据量大,数据相似度高 - 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。

微调的是什么?

  • 换数据源
  • 针对K层进行重新训练
  • K层的权重&shape调整

1、模型微调(fine-tune)一般流程:

  • 1、在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型
  • 2、创建一个新的神经网络模型,即目标模型,它复制了源模型上除了输出层外的所有模型设计及其参数
  • 3、为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数
  • 4、在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-x43rJPAE-1692785269018)(attachment:image.png)]

2、torchvision微调

2.1 实例化Model

import torchvision.models as models
resnet34 = models.resnet34(pretrained=True)

pretrained参数说明:

  • 1、通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False,意味着我们不使用预训练得到的权重
  • 2、当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重

注意:如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,否则会报错。

2.2 训练特定层

如果我们正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变。那我们就需要通过设置requires_grad = False来冻结部分层

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

2.3 实例

  • 使用resnet34为例的将1000类改为10类,但是仅改变最后一层的模型参数
  • 我们先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.models as models
from torchinfo import summary
#超参数定义
# 批次的大小
batch_size = 16 #可选32、64、128
# 优化器的学习率
lr = 1e-4
#运行epoch
max_epochs = 2
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 
# 数据读取
#cifar10数据集为例给出构建Dataset类的方式
from torchvision import datasets

#“data_transform”可以对图像进行一定的变换,如翻转、裁剪、归一化等操作,可自己定义
data_transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                   ])


train_cifar_dataset = datasets.CIFAR10('cifar10',train=True, download=False,transform=data_transform)
test_cifar_dataset = datasets.CIFAR10('cifar10',train=False, download=False,transform=data_transform)

#构建好Dataset后,就可以使用DataLoader来按批次读入数据了
train_loader = torch.utils.data.DataLoader(train_cifar_dataset, 
                                           batch_size=batch_size, num_workers=4, 
                                           shuffle=True, drop_last=True)

test_loader = torch.utils.data.DataLoader(test_cifar_dataset, 
                                         batch_size=batch_size, num_workers=4, 
                                         shuffle=False)


# 下载预训练模型 restnet50
resnet34 = models.resnet34(pretrained=True)
print(resnet34)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to C:\Users\xulele/.cache\torch\hub\checkpoints\resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:10<00:00, 8.57MB/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
#查看模型结构
summary(resnet34, (1, 3, 224, 224)) 
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
│    └─BasicBlock: 2-3                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-18                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-19                 [1, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-20            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-21                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          256
│    │    └─Sequential: 3-24             [1, 128, 28, 28]          8,448
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-5                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-26                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-27            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-6                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-34                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-36            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-37                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-7                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-40                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-41                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-42            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-43                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-8                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-44                 [1, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-45            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-46                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-47                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-48            [1, 256, 14, 14]          512
│    │    └─Sequential: 3-49             [1, 256, 14, 14]          33,280
│    │    └─ReLU: 3-50                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-9                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-51                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-52            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-53                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-54                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-55            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-10                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-62                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-11                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-63                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-64            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-65                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-66                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-67            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-68                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-69                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-70            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-71                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-72                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-73            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-74                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-75                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-76            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-14                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-81                 [1, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-82            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-83                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-84                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-85            [1, 512, 7, 7]            1,024
│    │    └─Sequential: 3-86             [1, 512, 7, 7]            132,096
│    │    └─ReLU: 3-87                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-15                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-88                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-89            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-90                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-91                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-92            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-93                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-16                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-94                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-95            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-96                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-97                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-98            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-99                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 21,797,672
Trainable params: 21,797,672
Non-trainable params: 0
Total mult-adds (G): 3.66
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 59.82
Params size (MB): 87.19
Estimated Total Size (MB): 147.61
==========================================================================================
#检测 模型准确率
def cal_predict_correct(model):
    test_total_correct = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
    
        outputs = model(images)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
#     print("test_total_correct: "+ str(test_total_correct))
    return test_total_correct
total_correct = cal_predict_correct(resnet34)
print("test_total_correct: "+ str(test_total_correct / 10000))
test_total_correct: 0.1
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
            

# 冻结参数的梯度
feature_extract = True
new_model = resnet34
set_parameter_requires_grad(new_model, feature_extract)

# 修改模型
#训练过程中,model仍会进行梯度回传,但是参数更新则只会发生在fc层
num_ftrs = new_model.fc.in_features
new_model.fc = nn.Linear(in_features=num_ftrs, out_features=10, bias=True)


summary(new_model, (1, 3, 224, 224)) 
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 10]                   --
├─Conv2d: 1-1                            [1, 64, 112, 112]         (9,408)
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         (128)
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
│    └─BasicBlock: 2-3                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-14            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-15                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-16                 [1, 64, 56, 56]           (36,864)
│    │    └─BatchNorm2d: 3-17            [1, 64, 56, 56]           (128)
│    │    └─ReLU: 3-18                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-19                 [1, 128, 28, 28]          (73,728)
│    │    └─BatchNorm2d: 3-20            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-21                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-22                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-23            [1, 128, 28, 28]          (256)
│    │    └─Sequential: 3-24             [1, 128, 28, 28]          (8,448)
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-5                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-26                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-27            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-28                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-29                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-30            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-31                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-6                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-32                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-33            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-34                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-35                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-36            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-37                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-7                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-38                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-39            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-40                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-41                 [1, 128, 28, 28]          (147,456)
│    │    └─BatchNorm2d: 3-42            [1, 128, 28, 28]          (256)
│    │    └─ReLU: 3-43                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-8                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-44                 [1, 256, 14, 14]          (294,912)
│    │    └─BatchNorm2d: 3-45            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-46                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-47                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-48            [1, 256, 14, 14]          (512)
│    │    └─Sequential: 3-49             [1, 256, 14, 14]          (33,280)
│    │    └─ReLU: 3-50                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-9                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-51                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-52            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-53                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-54                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-55            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-56                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-10                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-57                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-58            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-59                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-60                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-61            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-62                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-11                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-63                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-64            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-65                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-66                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-67            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-68                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-12                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-69                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-70            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-71                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-72                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-73            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-74                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-13                  [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-75                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-76            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-77                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-78                 [1, 256, 14, 14]          (589,824)
│    │    └─BatchNorm2d: 3-79            [1, 256, 14, 14]          (512)
│    │    └─ReLU: 3-80                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-14                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-81                 [1, 512, 7, 7]            (1,179,648)
│    │    └─BatchNorm2d: 3-82            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-83                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-84                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-85            [1, 512, 7, 7]            (1,024)
│    │    └─Sequential: 3-86             [1, 512, 7, 7]            (132,096)
│    │    └─ReLU: 3-87                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-15                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-88                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-89            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-90                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-91                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-92            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-93                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-16                  [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-94                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-95            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-96                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-97                 [1, 512, 7, 7]            (2,359,296)
│    │    └─BatchNorm2d: 3-98            [1, 512, 7, 7]            (1,024)
│    │    └─ReLU: 3-99                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 10]                   5,130
==========================================================================================
Total params: 21,289,802
Trainable params: 5,130
Non-trainable params: 21,284,672
Total mult-adds (G): 3.66
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 59.81
Params size (MB): 85.16
Estimated Total Size (MB): 145.57
==========================================================================================
#训练&验证
Resnet34_new = new_model.to(device)
# 定义损失函数和优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 损失函数:自定义损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(Resnet50_new.parameters(), lr=lr)
epoch = max_epochs

total_step = len(train_loader)
train_all_loss = []
test_all_loss = []

for i in range(epoch):
    Resnet34_new.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0

    for iter, (images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = Resnet34_new(images)
        loss = criterion(outputs,labels)
        train_total_correct += (outputs.argmax(1) == labels).sum().item()
        
        #backword
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))
    
    Resnet34_new.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    for iter,(images,labels) in enumerate(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = Resnet34_new(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100
    
    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num,4))
    test_all_loss.append(np.round(test_total_loss / test_total_num,4))

Epoch [1/2], Iter [1481/3125], train_loss:0.17220

半精度训练

问题:

GPU的性能主要分为两部分:算力和显存。
前者决定了显卡计算的速度,后者则决定了显卡可以同时放入多少数据用于计算
在可以使用的显存数量一定的情况下,每次训练能够加载的数据更多(也就是batch size更大),则也可以提高训练效率
定义:

PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性
但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度”
image.png

显然半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算

3.1、半精度训练的设置
1、引入 from torch.cuda.amp import autocast
2、forward函数指定 autocast 装饰器
3、训练过程: 只需在将数据输入模型及其之后的部分放入“with autocast():“
4、半精度训练主要适用于数据本身的size比较大(比如说3D图像、视频等)

引入


from torch.cuda.amp import autocast

# forward指定装饰器
@autocast()   
def forward(self, x):
    ...
    return x

# 指定with autocast 
 for x in train_loader:
    x = x.cuda()
    with autocast():
            output = model(x)
        ...

半精度训练案例
from torch.cuda.amp import autocast

半精度模型

class DemoModel(nn.Module):
def init(self):
super(DemoModel, self).init()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

@autocast() 
def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x





#训练&验证

device = torch.device(‘cuda:0’ if torch.cuda.is_available() else ‘cpu’)
half_model = DemoModel().to(device)

损失函数:自定义损失函数

criterion = nn.CrossEntropyLoss()

优化器

optimizer = torch.optim.Adam(Resnet50_new.parameters(), lr=lr)
epoch = max_epochs

total_step = len(train_loader)
train_all_loss = []
test_all_loss = []

for i in range(epoch):
half_model.train()
train_total_loss = 0
train_total_num = 0
train_total_correct = 0

for iter, (images,labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
with autocast():
outputs = half_model(images)
loss = criterion(outputs,labels)
train_total_correct += (outputs.argmax(1) == labels).sum().item()

#backword
optimizer.zero_grad()
loss.backward()
optimizer.step()

        train_total_num += labels.shape[0]
        train_total_loss += loss.item()
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))

half_model.eval()
test_total_loss = 0
test_total_correct = 0
test_total_num = 0
for iter,(images,labels) in enumerate(test_loader):
    images = images.to(device)
    labels = labels.to(device)
    with autocast():
        outputs = half_model(images)
        loss = criterion(outputs,labels)
        test_total_correct += (outputs.argmax(1) == labels).sum().item()
        test_total_loss += loss.item()
        test_total_num += labels.shape[0]
        print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
            i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100

))
train_all_loss.append(np.round(train_total_loss / train_total_num,4))
test_all_loss.append(np.round(test_total_loss / test_total_num,4))

你可能感兴趣的:(pytorch,pytorch,人工智能,python)