(1)Pytorch模型保存和加载

Pytorch的模型保存和加载

官方模型加载

# 官方模型加载
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=True)
vgg16 = torchvision.models.vgg16(pretrained=False) 
vgg16 = torchvision.models.vgg16() # 默认pretrained=False

模型保存

# 模型的保存
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
# 模型保存的两种方式
# 保存方法一(模型和参数一起保存)
torch.save(vgg16, "vgg16_method1.pth")
# 保存方法二(仅保存参数,官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")   # 将模型状态(参数)保存为字典形式

模型调用

# 保存模型的调用
import torch
import torchvision

# 加载方式一
model1 = torch.load("vgg16_method1.pth")
print("------model1------")
print(model1)

# 加载方式二
model2 = torch.load("vgg16_method2.pth")
print("------model2------")
print(model2)	#⭐注意看这个输出,真的只有参数噢,不保存网络结构
 
# 若要使用方式二加载模型参数和结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict((torch.load("vgg16_method2.pth")))
print(vgg16)

参考资料: 深度学习pytorch:VGG网络模型的使用、修改及保存、添加线性层、修改网络输出_学好迁移Learning的博客

你可能感兴趣的:(pytorch基础知识积累,pytorch,python,深度学习)