我们用 PyTorch 训练神经网络模型时，如果数据量较大或网络较深，训练往往会非常耗时。这时就需要定期保存中间结果（检查点，checkpoint），以便下次从检查点恢复参数、继续训练。
记录一下代码的主要写法:
import argparse  # standard-library command-line argument parsing

import torch
import torchvision  # provides datasets.CIFAR10
import torchvision.transforms as transforms
# Command-line options for the CIFAR-10 training script.
parse = argparse.ArgumentParser(description='Pytorch CIFAR10 Training')
# Initial learning rate.
parse.add_argument('--lr', default=0.1, type=float, help='learning rate')
# Flag: restore model/accuracy/epoch from a saved checkpoint instead of starting fresh.
parse.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
# parse_known_args() instead of parse_args(): launchers such as Jupyter inject
# their own flags (e.g. `-f <kernel.json>`), which parse_args() would reject
# with SystemExit. Unrecognized flags are reported rather than silently dropped.
args, unknown_args = parse.parse_known_args()
if unknown_args:
    print('==> Ignoring unrecognized arguments:', unknown_args)

use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy seen so far (overwritten when resuming)
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# Data
print('==>Preparing data...')
# Per-channel normalization statistics shared by the train and test pipelines,
# so both splits are scaled identically before reaching the model.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training-set augmentation: random 32x32 crop from the image padded by 4 pixels,
# then conversion to a normalized tensor. ToTensor is required here: without it
# the dataset yields PIL images, which DataLoader's default collate cannot batch.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Test-set preprocessing: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Training split (download=True fetches it into ./data when not already present),
# shuffled each epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)
# Test split, iterated in a fixed order (shuffle=False) for repeatable evaluation.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

# Human-readable class names; presumably indexed by CIFAR-10 label id (0-9).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Optionally restore the model, best accuracy and starting epoch from a checkpoint.
if args.resume:
    print('==> Resuming from checkpoint..')
    # Without CUDA, map tensors to CPU so a checkpoint saved on a GPU machine
    # still loads; with CUDA (map_location=None) tensors keep their saved devices.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. checkpoint['net'] appears to be a whole pickled model; on
    # PyTorch >= 2.6 torch.load defaults to weights_only=True and will refuse it
    # unless weights_only=False is passed (kwarg exists since 1.13) — confirm
    # the torch version in use.
    checkpoint = torch.load('./checkpoint/ckpt.t7', map_location=None if use_cuda else 'cpu')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
if use_cuda: