原博文地址:https://blog.csdn.net/weixin_41278720/article/details/80759933
Pytorch中的torchvision包又包括3个子包,分别如下:
torchvison.datasets:预定义好的数据集(比如MNIST、CIFAR10等)
torchvision.models :预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)
torchvision.transforms :预定义好的数据增强方法(比如Resize、ToTensor等)
models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
加载resnet50预训练模型(包含训练得到的权重与偏置参数)
import torchvision.models as models
resnet50 = models.resnet50(pretrained=True)
只加载resnet50网络结构,并未使用预训练模型的参数对其初始化(权重与偏置都是随机值)
import torchvision.models as models
resnet50 = models.resnet50(pretrained=False)
方法一(推荐):只保存和加载模型中的参数,不保存其网络结构
保存:将训练参数保存在ckp文件夹中,文件名:model.pth
torch.save(resnet50.state_dict(),'ckp/model.pth')
加载:这里的resnet50是我们自己实现的网络,因此可以不必传递pretrained=True参数(官方提供的版本需要传递此参数)
resnet=resnet50() #加载网络结构
resnet.load_state_dict(torch.load('ckp/model.pth')) #加载该网络结构的预训练参数
方法二:保存、加载网络的结构和参数信息
保存
torch.save(resnet50,'model.pth')
加载
resnet50 = torch.load('model.pth')
方法三:选择保存、加载网络中的一部分参数或者保存额外的参数
保存
save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
torch.save({
'session': args.session,
'epoch': epoch + 1,
'model': fasterRCNN.module.state_dict() if args.mGPUs else fasterRCNN.state_dict(),
'optimizer': optimizer.state_dict(),
'pooling_mode': cfg.POOLING_MODE,
'class_agnostic': args.class_agnostic,
}, save_name)
加载
load_name = os.path.join(output_dir,
'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
checkpoint = torch.load(load_name)
args.session = checkpoint['session']
args.start_epoch = checkpoint['epoch']
fasterRCNN.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr = optimizer.param_groups[0]['lr']
cfg.POOLING_MODE = checkpoint['pooling_mode']