小土堆-pytorch-神经网络-网络模型的保存和读取13_笔记

保存与读取方式一:

创建2个python空文件,模拟保存和读取
保存:

import torch
import torchvision
vgg16=torchvision.models.vgg16(weights=False)
# 保存方式一: 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")

保存完成后,会出现这样的文件
在这里插入图片描述
读取:

import torch
# 方式一 保存方式一,加载模型
model=torch.load("vgg16_method1.pth")
print(model)

运行结果截图:
小土堆-pytorch-神经网络-网络模型的保存和读取13_笔记_第1张图片

保存与读取方式二:

保存:

# 保存方式二: 以字典类型保存,保存它的参数(官方推荐) 模型参数
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

在这里插入图片描述
读取:

# 方式二:加载模型
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

运行结果截图:
小土堆-pytorch-神经网络-网络模型的保存和读取13_笔记_第2张图片
方式一的陷阱:

from torch import nn


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1=nn.Conv2d(3,64,3)
        
    def forward(self,x):
        x=self.conv1(x)
        return x
    
tudui=Tudui()
torch.save(tudui,"tuidui_method1.pth")

错误的加载

# 用原方式一的方式加载
model=torch.load('torch_method.pth')
print(model)

正确的加载

# 导入包
form model_save import *
# 然后正常加载

你可能感兴趣的:(pytorch,pytorch,神经网络,笔记,python,人工智能)