torch 模型加载和保存模型

先有一个模型:
my_resnet = MyResNet(*args, **kwargs)

两种加载权重方法:
1.基于推荐保存的方式
保存方式:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

对应的加载方式:

my_resnet.load_state_dict(torch.load("my_resnet.pth"))
  1. 直接torch.load
my_resnet = torch.load("my_resnet.pth")

加载部分预训练模型
PyTorch 中的 torchvision 里已经有很多常用的模型了,可以直接调用:AlexNet 、VGG 、ResNet 、SqueezeNet 、DenseNet

import torchvision.models as models
import torch.utils.model_zoo as model_zoo


pretrained_dict = model_zoo.load_url(model_urls['resnet152'])

model_dict = model.state_dict()

# 将 pretrained_dict 里不属于 model_dict 的键剔除掉

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 更新现有的 model_dict

model_dict.update(pretrained_dict)

# 加载我们真正需要的 state_dict

model.load_state_dict(model_dict)

ref
https://www.sohu.com/a/137839632_717210

你可能感兴趣的:(torch,pytorch学习)