Pytorch学习笔记:加载预训练模型

文章目录

  • 前言
  • 摘要:
    • 1.加载模型
    • 2.修改加载模型
    • 3.加载预训练模型
    • 4.保存模型
  • 回顾总结


前言

torchvision是官方经常需要的包,包括:
torchvision.datasets:预定义的训练集(比如MNIST、CIFAR10等);
torchvision.models:包含预定义好的经典网络结构(比如AlexNet、VGG、ResNet等),
torchvision.transforms:数据增强的方法

模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models
官方文档:https://pytorch.org/docs/master/torchvision/models.html


摘要:

本文记录的是读取已有网络结构和添加预训练模型


1.加载模型

代码如下:

import torchvision.models as models
resnet50 = models.resnet50(pretrained=True)
如果不需要采用torchvision预训练模型参数来初始化,将pretrained设置为False:
resnet50 = models.resnet50(pretrained=False)

2.修改加载模型

以resnet为例,默认的是ImageNet的1000类,比如我们要做二分类,分类猫和狗:

import torch
import torch.nn as nn
resnet.fc = nn.Linear(2048, 2) 

此处复习了原模型中最后的全连接层。

3.加载预训练模型

在实际使用时,通常都会对预训练网络进行修改,那么预训练的参数就不能完全的使用,对两者进行比对,选择相同的参数加载进来

#加载model,model是自己定义好的模型
resnet50 = models.resnet50(pretrained=True) 
model =Net(...) 
 
#读取参数 预训练参数和当前网络参数
pretrained_dict =resnet50.state_dict() 
model_dict = model.state_dict() 
 
#将pretrained_dict里不属于model_dict的键剔除掉 
pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict} 
 
# 更新现有的model_dict 
model_dict.update(pretrained_dict) 
 
# 加载我们真正需要的state_dict 
model.load_state_dict(model_dict)  

4.保存模型

方法1:只保存参数,模型自己调用,以resnet为例:

#只保存参数
torch.save(resnet50.state_dict(),'ckp/model.pth')  
#先导入模型结构,再调用保存的参数
resnet=resnet50(pretrained=True)
resnet.load_state_dict(torch.load('ckp/model.pth'))

方法2:模型和参数全部保存:

#保存
torch.save (model, PATH)
#恢复
model = torch.load(PATH)

回顾总结

1.如何调用官方模型
2.如何改写网络
3.如何加载预训练网络:

  1. 读取网络参数
  2. 剔除预训练网络中的不属于当前网络的参数
  3. 更新预训练模型的网络参数
  4. 加载更新的网络参数

4.保存和恢复模型的两种方法

参考:

文章1

你可能感兴趣的:(Pytorch学习笔记,pytorch)