机器学习——网络模型的保存与读取

加载网络模型,pretrained=False 网络中模型的参数是没有训练的,初始化参数。

import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)

1.有两种保存方式:torch.save()

①保存网络的 结构+参数

torch.save(vgg16, "vgg16_method1.pth")

②方式2【官方推荐】,网络模型的参数保存为字典,不保存结构。空间小

state_dict()返回包含模块整个状态的字典
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

运行后,出现保存的文件。

2.网络的加载方式:torch.load()

对应方式①打印出的是网络模型的结构

# 加载方式1模型
import torch
# 打印出的是网络模型的结构
model1 = torch.load("vgg16_method1.pth")
print(model1)

对应方式②打印出的参数是字典形式

model2 = torch.load("vgg16_method2.pth")
print(model2)

如何恢复成网络模型?

第一步:新建网络模型

vgg16 = torchvision.models.vgg16(pretrained=False)

第二步:加载网络参数模型

model2 = torch.load("vgg16_method2.pth")

第三步:调用load_state_dict,恢复成网络结构

vgg16.load_state_dict(model2)

你可能感兴趣的:(python,开发语言)