数据集1 - 数据量少,但数据相似度非常高 - 在这种情况下,我们只需修改最后几层,或者只修改最终softmax层的输出类别数。
数据集2 - 数据量少,数据相似度低 - 在这种情况下,我们可以冻结预训练模型的前k层,并重新训练剩余的(n-k)层。由于新数据集与原数据集的相似度较低,根据新数据集重新训练较高层(它们提取的是更抽象、更依赖具体任务的特征)具有重要意义。
数据集3 - 数据量大,数据相似度低 - 在这种情况下,由于我们有一个大的数据集,神经网络的训练将会很有效。但是,由于我们的数据与训练预训练模型所用的数据有很大不同,使用预训练模型进行的预测不会有效。因此,最好根据你的数据从头开始训练神经网络(Training from scratch)。
数据集4 - 数据量大,数据相似度高 - 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构,并以预训练模型的权重作为初始权重,然后在新数据集上重新训练(微调)该模型。
import torchvision.models as models
# `pretrained=True` is deprecated since torchvision 0.13; passing the weight
# enum explicitly loads the same ImageNet-1k checkpoint without the warnings.
resnet34 = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
如果我们只是提取特征,只想为新初始化的层计算梯度而保持其他参数不变,就需要通过设置requires_grad = False来冻结其余的层。
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Frozen parameters get ``requires_grad = False`` so autograd stops
    computing gradients for them. When ``feature_extracting`` is falsy the
    model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.models as models
from torchinfo import summary
# Batch size (32, 64 or 128 are also possible choices).
batch_size = 16
# Learning rate of the optimizer.
lr = 1e-4
# Number of training epochs.
max_epochs = 2
# Option 2: define a single `device` and move every variable that should run on
# the GPU with .to(device). "cuda" resolves to the current CUDA device, so this
# also works on single-GPU machines (a hard-coded "cuda:1" raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Data loading.
from torchvision import datasets
# NOTE(review): data_transform is not defined in this section; it must be
# created before this point (e.g. a transforms.Compose pipeline) — confirm.
# download=False: the CIFAR-10 files must already exist under ./cifar10.
train_cifar_dataset = datasets.CIFAR10('cifar10', train=True, download=False, transform=data_transform)
test_cifar_dataset = datasets.CIFAR10('cifar10', train=False, download=False, transform=data_transform)
train_loader = torch.utils.data.DataLoader(train_cifar_dataset,
                                           batch_size=batch_size, num_workers=4,
                                           shuffle=True, drop_last=True)
# Evaluation loader: no shuffling and no drop_last, so every test sample is
# seen exactly once when computing accuracy/loss.
test_loader = torch.utils.data.DataLoader(test_cifar_dataset,
                                          batch_size=batch_size, num_workers=4,
                                          shuffle=False)
# Download the pretrained ResNet-34 (ImageNet-1k weights). The weight enum
# replaces the deprecated `pretrained=True` and loads the same checkpoint.
resnet34 = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to C:\Users\xulele/.cache\torch\hub\checkpoints\resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:10<00:00, 8.57MB/s]
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
# Print the layer-by-layer summary for a single 3x224x224 input image.
summary(resnet34, input_size=(1, 3, 224, 224))
Layer (type:depth-idx) Output Shape Param #
ResNet [1, 1000] --
├─Conv2d: 1-1 [1, 64, 112, 112] 9,408
├─BatchNorm2d: 1-2 [1, 64, 112, 112] 128
├─ReLU: 1-3 [1, 64, 112, 112] --
├─MaxPool2d: 1-4 [1, 64, 56, 56] --
├─Sequential: 1-5 [1, 64, 56, 56] --
│ └─BasicBlock: 2-1 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-1 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-2 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-3 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-5 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-6 [1, 64, 56, 56] --
│ └─BasicBlock: 2-2 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-7 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-8 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-9 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-10 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-11 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-12 [1, 64, 56, 56] --
│ └─BasicBlock: 2-3 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-13 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-14 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-15 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-16 [1, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-17 [1, 64, 56, 56] 128
│ │ └─ReLU: 3-18 [1, 64, 56, 56] --
├─Sequential: 1-6 [1, 128, 28, 28] --
│ └─BasicBlock: 2-4 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-19 [1, 128, 28, 28] 73,728
│ │ └─BatchNorm2d: 3-20 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-21 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-22 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-23 [1, 128, 28, 28] 256
│ │ └─Sequential: 3-24 [1, 128, 28, 28] 8,448
│ │ └─ReLU: 3-25 [1, 128, 28, 28] --
│ └─BasicBlock: 2-5 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-26 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-27 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-28 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-29 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-30 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-31 [1, 128, 28, 28] --
│ └─BasicBlock: 2-6 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-32 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-33 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-34 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-35 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-36 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-37 [1, 128, 28, 28] --
│ └─BasicBlock: 2-7 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-38 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-39 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-40 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-41 [1, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-42 [1, 128, 28, 28] 256
│ │ └─ReLU: 3-43 [1, 128, 28, 28] --
├─Sequential: 1-7 [1, 256, 14, 14] --
│ └─BasicBlock: 2-8 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-44 [1, 256, 14, 14] 294,912
│ │ └─BatchNorm2d: 3-45 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-46 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-47 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-48 [1, 256, 14, 14] 512
│ │ └─Sequential: 3-49 [1, 256, 14, 14] 33,280
│ │ └─ReLU: 3-50 [1, 256, 14, 14] --
│ └─BasicBlock: 2-9 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-51 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-52 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-53 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-54 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-55 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-56 [1, 256, 14, 14] --
│ └─BasicBlock: 2-10 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-57 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-58 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-59 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-60 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-61 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-62 [1, 256, 14, 14] --
│ └─BasicBlock: 2-11 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-63 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-64 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-65 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-66 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-67 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-68 [1, 256, 14, 14] --
│ └─BasicBlock: 2-12 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-69 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-70 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-71 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-72 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-73 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-74 [1, 256, 14, 14] --
│ └─BasicBlock: 2-13 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-75 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-76 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-77 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-78 [1, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-79 [1, 256, 14, 14] 512
│ │ └─ReLU: 3-80 [1, 256, 14, 14] --
├─Sequential: 1-8 [1, 512, 7, 7] --
│ └─BasicBlock: 2-14 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-81 [1, 512, 7, 7] 1,179,648
│ │ └─BatchNorm2d: 3-82 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-83 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-84 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-85 [1, 512, 7, 7] 1,024
│ │ └─Sequential: 3-86 [1, 512, 7, 7] 132,096
│ │ └─ReLU: 3-87 [1, 512, 7, 7] --
│ └─BasicBlock: 2-15 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-88 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-89 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-90 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-91 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-92 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-93 [1, 512, 7, 7] --
│ └─BasicBlock: 2-16 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-94 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-95 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-96 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-97 [1, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-98 [1, 512, 7, 7] 1,024
│ │ └─ReLU: 3-99 [1, 512, 7, 7] --
├─AdaptiveAvgPool2d: 1-9 [1, 512, 1, 1] --
├─Linear: 1-10 [1, 1000] 513,000
Total params: 21,797,672
Trainable params: 21,797,672
Non-trainable params: 0
Total mult-adds (G): 3.66
Input size (MB): 0.60
Forward/backward pass size (MB): 59.82
Params size (MB): 87.19
Estimated Total Size (MB): 147.61
# Check model accuracy: count correct top-1 predictions.
def cal_predict_correct(model, loader=None, device=None):
    """Return the number of correctly classified samples in ``loader``.

    Args:
        model: classifier whose output is compared via ``argmax(1)`` against
            the labels (assumed to be logits of shape (batch, num_classes)).
        loader: iterable of ``(images, labels)`` batches; defaults to the
            module-level ``test_loader``.
        device: device the batches are moved to; defaults to the device of
            the model's first parameter (CPU if it has none), so inputs always
            live on the same device as the weights.

    The model is evaluated in eval mode (BatchNorm uses, and does not update,
    its running statistics) under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards.
    """
    if loader is None:
        loader = test_loader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    test_total_correct = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                test_total_correct += (outputs.argmax(1) == labels).sum().item()
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)
    return test_total_correct
total_correct = cal_predict_correct(resnet34)
# Accuracy = correct predictions / number of test samples. The head still has
# the 1000 ImageNet outputs, so a result near chance level (~0.1) is expected.
print("test_total_correct: " + str(total_correct / len(test_loader.dataset)))
test_total_correct: 0.1
def set_parameter_requires_grad(model, feature_extracting):
    """Turn off gradient tracking for all of ``model``'s parameters.

    Acts only in feature-extraction mode (``feature_extracting`` truthy),
    where the pretrained weights must stay fixed; otherwise it does nothing.
    """
    if feature_extracting:
        for p in model.parameters():
            p.requires_grad_(False)
# Freeze the gradients of the pretrained parameters.
feature_extract = True
# NOTE: an alias, not a copy — changes to new_model also change resnet34.
new_model = resnet34
set_parameter_requires_grad(model=new_model, feature_extracting=feature_extract)
# Replace the classifier head with a fresh (trainable) Linear layer mapping the
# backbone features to the 10 CIFAR-10 classes.
num_ftrs = new_model.fc.in_features
new_model.fc = nn.Linear(num_ftrs, 10, bias=True)
summary(new_model, input_size=(1, 3, 224, 224))
Layer (type:depth-idx) Output Shape Param #
ResNet [1, 10] --
├─Conv2d: 1-1 [1, 64, 112, 112] (9,408)
├─BatchNorm2d: 1-2 [1, 64, 112, 112] (128)
├─ReLU: 1-3 [1, 64, 112, 112] --
├─MaxPool2d: 1-4 [1, 64, 56, 56] --
├─Sequential: 1-5 [1, 64, 56, 56] --
│ └─BasicBlock: 2-1 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-1 [1, 64, 56, 56] (36,864)
│ │ └─BatchNorm2d: 3-2 [1, 64, 56, 56] (128)
│ │ └─ReLU: 3-3 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [1, 64, 56, 56] (36,864)
│ │ └─BatchNorm2d: 3-5 [1, 64, 56, 56] (128)
│ │ └─ReLU: 3-6 [1, 64, 56, 56] --
│ └─BasicBlock: 2-2 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-7 [1, 64, 56, 56] (36,864)
│ │ └─BatchNorm2d: 3-8 [1, 64, 56, 56] (128)
│ │ └─ReLU: 3-9 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-10 [1, 64, 56, 56] (36,864)
│ │ └─BatchNorm2d: 3-11 [1, 64, 56, 56] (128)
│ │ └─ReLU: 3-12 [1, 64, 56, 56] --
│ └─BasicBlock: 2-3 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-13 [1, 64, 56, 56] (36,864)
│ │ └─BatchNorm2d: 3-14 [1, 64, 56, 56] (128)
│ │ └─ReLU: 3-15 [1, 64, 56, 56] --
│ │ └─Conv2d: 3-16 [1, 64, 56, 56] (36,864)
│ │ └─BatchNorm2d: 3-17 [1, 64, 56, 56] (128)
│ │ └─ReLU: 3-18 [1, 64, 56, 56] --
├─Sequential: 1-6 [1, 128, 28, 28] --
│ └─BasicBlock: 2-4 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-19 [1, 128, 28, 28] (73,728)
│ │ └─BatchNorm2d: 3-20 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-21 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-22 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-23 [1, 128, 28, 28] (256)
│ │ └─Sequential: 3-24 [1, 128, 28, 28] (8,448)
│ │ └─ReLU: 3-25 [1, 128, 28, 28] --
│ └─BasicBlock: 2-5 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-26 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-27 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-28 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-29 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-30 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-31 [1, 128, 28, 28] --
│ └─BasicBlock: 2-6 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-32 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-33 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-34 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-35 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-36 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-37 [1, 128, 28, 28] --
│ └─BasicBlock: 2-7 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-38 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-39 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-40 [1, 128, 28, 28] --
│ │ └─Conv2d: 3-41 [1, 128, 28, 28] (147,456)
│ │ └─BatchNorm2d: 3-42 [1, 128, 28, 28] (256)
│ │ └─ReLU: 3-43 [1, 128, 28, 28] --
├─Sequential: 1-7 [1, 256, 14, 14] --
│ └─BasicBlock: 2-8 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-44 [1, 256, 14, 14] (294,912)
│ │ └─BatchNorm2d: 3-45 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-46 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-47 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-48 [1, 256, 14, 14] (512)
│ │ └─Sequential: 3-49 [1, 256, 14, 14] (33,280)
│ │ └─ReLU: 3-50 [1, 256, 14, 14] --
│ └─BasicBlock: 2-9 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-51 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-52 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-53 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-54 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-55 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-56 [1, 256, 14, 14] --
│ └─BasicBlock: 2-10 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-57 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-58 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-59 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-60 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-61 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-62 [1, 256, 14, 14] --
│ └─BasicBlock: 2-11 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-63 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-64 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-65 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-66 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-67 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-68 [1, 256, 14, 14] --
│ └─BasicBlock: 2-12 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-69 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-70 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-71 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-72 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-73 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-74 [1, 256, 14, 14] --
│ └─BasicBlock: 2-13 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-75 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-76 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-77 [1, 256, 14, 14] --
│ │ └─Conv2d: 3-78 [1, 256, 14, 14] (589,824)
│ │ └─BatchNorm2d: 3-79 [1, 256, 14, 14] (512)
│ │ └─ReLU: 3-80 [1, 256, 14, 14] --
├─Sequential: 1-8 [1, 512, 7, 7] --
│ └─BasicBlock: 2-14 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-81 [1, 512, 7, 7] (1,179,648)
│ │ └─BatchNorm2d: 3-82 [1, 512, 7, 7] (1,024)
│ │ └─ReLU: 3-83 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-84 [1, 512, 7, 7] (2,359,296)
│ │ └─BatchNorm2d: 3-85 [1, 512, 7, 7] (1,024)
│ │ └─Sequential: 3-86 [1, 512, 7, 7] (132,096)
│ │ └─ReLU: 3-87 [1, 512, 7, 7] --
│ └─BasicBlock: 2-15 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-88 [1, 512, 7, 7] (2,359,296)
│ │ └─BatchNorm2d: 3-89 [1, 512, 7, 7] (1,024)
│ │ └─ReLU: 3-90 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-91 [1, 512, 7, 7] (2,359,296)
│ │ └─BatchNorm2d: 3-92 [1, 512, 7, 7] (1,024)
│ │ └─ReLU: 3-93 [1, 512, 7, 7] --
│ └─BasicBlock: 2-16 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-94 [1, 512, 7, 7] (2,359,296)
│ │ └─BatchNorm2d: 3-95 [1, 512, 7, 7] (1,024)
│ │ └─ReLU: 3-96 [1, 512, 7, 7] --
│ │ └─Conv2d: 3-97 [1, 512, 7, 7] (2,359,296)
│ │ └─BatchNorm2d: 3-98 [1, 512, 7, 7] (1,024)
│ │ └─ReLU: 3-99 [1, 512, 7, 7] --
├─AdaptiveAvgPool2d: 1-9 [1, 512, 1, 1] --
├─Linear: 1-10 [1, 10] 5,130
Total params: 21,289,802
Trainable params: 5,130
Non-trainable params: 21,284,672
Total mult-adds (G): 3.66
Input size (MB): 0.60
Forward/backward pass size (MB): 59.81
Params size (MB): 85.16
Estimated Total Size (MB): 145.57
# Move the model to the device selected at the top of the script.
Resnet34_new = new_model.to(device)
# NOTE: `device` is deliberately NOT re-assigned here. The model has just been
# moved to `device`; re-binding it would send every batch to a different
# device than the weights.
# Loss function: cross-entropy over the class logits (averaged over the batch).
criterion = nn.CrossEntropyLoss()
# Optimizer: the backbone is frozen, so only the parameters that still require
# gradients (the new fc layer) are handed to Adam.
optimizer = torch.optim.Adam(
    [p for p in Resnet34_new.parameters() if p.requires_grad], lr=lr
)
epoch = max_epochs
total_step = len(train_loader)
train_all_loss = []
test_all_loss = []
for i in range(epoch):
    # ---- training pass ----
    Resnet34_new.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = Resnet34_new(images)
        loss = criterion(outputs, labels)
        # Back-propagate and update; without these calls nothing is learned.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_num = labels.shape[0]
        train_total_correct += (outputs.argmax(1) == labels).sum().item()
        train_total_num += batch_num
        # `loss` is already the batch mean; weight it by the batch size so that
        # train_total_loss / train_total_num is the mean loss per sample.
        train_total_loss += loss.item() * batch_num
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:.4f}".format(
            i + 1, epoch, step + 1, total_step, loss.item()))
    # ---- evaluation pass: eval mode (BatchNorm uses running stats), no autograd ----
    Resnet34_new.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = Resnet34_new(images)
            loss = criterion(outputs, labels)
            batch_num = labels.shape[0]
            test_total_correct += (outputs.argmax(1) == labels).sum().item()
            test_total_loss += loss.item() * batch_num
            test_total_num += batch_num
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i + 1, epoch,
        train_total_loss / train_total_num, train_total_correct / train_total_num * 100,
        test_total_loss / test_total_num, test_total_correct / test_total_num * 100,
    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num, 4))
    test_all_loss.append(np.round(test_total_loss / test_total_num, 4))
Epoch [1/2], Iter [1481/3125], train_loss:0.17220
使用半精度(float16)训练可以减少显存占用:在可用显存一定的情况下,每次训练能够加载更多的数据(也就是更大的batch size),从而提高训练效率。
1、引入 from torch.cuda.amp import autocast
2、forward函数指定 autocast 装饰器
3、训练过程:只需将数据输入模型及其之后的部分(前向计算和损失计算)放入“with autocast():”中。训练时通常还需配合 torch.cuda.amp.GradScaler 对损失进行缩放,以避免float16梯度下溢为0。
from torch.cuda.amp import autocast
# Decorate forward with autocast so the operations inside it run in mixed
# precision (the decorator the original comment asked for was missing).
@autocast()
def forward(self, x):
    return x
# Run the forward pass inside `with autocast()`. The loader yields
# (images, labels) batches, so unpack them before moving the images to the GPU.
# NOTE(review): `model` is assumed to be defined elsewhere — illustrative fragment.
for x, labels in train_loader:
    x = x.cuda()
    with autocast():
        output = model(x)
from torch.cuda.amp import autocast
class DemoModel(nn.Module):
    """Small LeNet-style CNN that outputs 10 class logits.

    The flatten size 16 * 5 * 5 corresponds to 3x32x32 inputs
    (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5).
    """

    def __init__(self):
        # Must be __init__ (not "init"): otherwise nn.Module is never
        # initialised and none of the layers below are created.
        super(DemoModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension; a wrong input size
        # then raises in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# NOTE(review): DemoModel's fc1 expects 16 * 5 * 5 features, i.e. 3x32x32
# inputs; confirm that data_transform does not resize the CIFAR-10 images.
half_model = DemoModel().to(device)
# Loss function: cross-entropy over the 10 class logits (batch mean).
criterion = nn.CrossEntropyLoss()
# Optimize the parameters of the model trained in this section.
optimizer = torch.optim.Adam(half_model.parameters(), lr=lr)
# torch.cuda.amp only has an effect on CUDA; on CPU both autocast and the
# gradient scaler are switched off so the loop still runs.
use_amp = device.type == 'cuda'
# GradScaler scales the loss before backward() so that small float16
# gradients do not underflow to zero, and unscales them before the update.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
epoch = max_epochs
total_step = len(train_loader)
train_all_loss = []
test_all_loss = []
for i in range(epoch):
    # ---- training pass ----
    half_model.train()
    train_total_loss = 0
    train_total_num = 0
    train_total_correct = 0
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass and loss in mixed precision.
        with autocast(enabled=use_amp):
            outputs = half_model(images)
            loss = criterion(outputs, labels)
        # Scaled backward pass and optimizer update (outside autocast).
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_num = labels.shape[0]
        train_total_correct += (outputs.argmax(1) == labels).sum().item()
        train_total_num += batch_num
        # `loss` is the batch mean; weight it by the batch size so that
        # train_total_loss / train_total_num is the mean loss per sample.
        train_total_loss += loss.item() * batch_num
        print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:.4f}".format(
            i + 1, epoch, step + 1, total_step, loss.item()))
    # ---- evaluation pass: eval mode, no autograd graph ----
    half_model.eval()
    test_total_loss = 0
    test_total_correct = 0
    test_total_num = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            with autocast(enabled=use_amp):
                outputs = half_model(images)
                loss = criterion(outputs, labels)
            batch_num = labels.shape[0]
            test_total_correct += (outputs.argmax(1) == labels).sum().item()
            test_total_loss += loss.item() * batch_num
            test_total_num += batch_num
    print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(
        i + 1, epoch,
        train_total_loss / train_total_num, train_total_correct / train_total_num * 100,
        test_total_loss / test_total_num, test_total_correct / test_total_num * 100,
    ))
    train_all_loss.append(np.round(train_total_loss / train_total_num, 4))
    test_all_loss.append(np.round(test_total_loss / test_total_num, 4))