Pytorch保存与加载模型

原博文地址:https://blog.csdn.net/weixin_41278720/article/details/80759933

Pytorch中的torchvision包又包括3个子包,分别如下:

torchvison.datasets:预定义好的数据集(比如MNIST、CIFAR10等)

torchvision.models :预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)

torchvision.transforms :预定义好的数据增强方法(比如Resize、ToTensor等)

models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

1、加载官方提供的网络模型

加载resnet50预训练模型(包含训练得到的权重与偏置参数)

import torchvision.models as models
 
resnet50 = models.resnet50(pretrained=True)

只加载resnet50网络结构,并未使用预训练模型的参数对其初始化(权重与偏置都是随机值)

import torchvision.models as models
 
resnet50 = models.resnet50(pretrained=False)

2、保存、加载自己的网络模型

方法一(推荐):只保存和加载模型中的参数,不保存其网络结构

保存:将训练参数保存在ckp文件夹中,文件名:model.pth

torch.save(resnet50.state_dict(),'ckp/model.pth') 

加载:这里的resnet50是我们自己实现的网络,因此可以不必传递pretrained=True参数(官方提供的版本需要传递此参数)

resnet=resnet50()    #加载网络结构
resnet.load_state_dict(torch.load('ckp/model.pth'))  #加载该网络结构的预训练参数

方法二:保存、加载网络的结构和参数信息

保存

torch.save(resnet50,'model.pth') 

加载

resnet50 = torch.load('model.pth')

方法三:选择保存、加载网络中的一部分参数或者保存额外的参数

保存

save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
torch.save({
      'session': args.session,
      'epoch': epoch + 1,
      'model': fasterRCNN.module.state_dict() if args.mGPUs else fasterRCNN.state_dict(),
      'optimizer': optimizer.state_dict(),
      'pooling_mode': cfg.POOLING_MODE,
      'class_agnostic': args.class_agnostic,
    }, save_name)

加载

load_name = os.path.join(output_dir,
      'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
checkpoint = torch.load(load_name)
args.session = checkpoint['session']
args.start_epoch = checkpoint['epoch']
fasterRCNN.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr = optimizer.param_groups[0]['lr']
cfg.POOLING_MODE = checkpoint['pooling_mode']

 

你可能感兴趣的:(#,Pytorch)