【学习笔记】【Pytorch】14、网络模型的保存与读取

【学习笔记】【Pytorch】14、网络模型的保存与读取

  • 一、网络模型的保存
  • 二、网络模型的读取

一、网络模型的保存

Pytorch提供了两种方式进行保存模型。

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")  # 保存模型结构及参数

# 保存方式2:模型参数,保存成字典的形式(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 陷阱1:方式1保存模型,陷阱在加载处
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()  # 初始化父类属性
        self.model1 = Sequential(
            Conv2d(3, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
        )

    def forward(self, x):
        x = self.model1(x)
        return x
model = Model()
torch.save(model, "model_method.pth")  # 保存模型结构及参数

二、网络模型的读取

Pytorch提供了两种方式进行读取模型。
注意:读取模型时,必须引入该模型结构的class定义,否则加载模型时报错缺少类定义。

import torch
import torchvision.models
from Model import Model   # 引入模型类,防止加载自定义模型报错

# 方式1:加载模型
model1 = torch.load("vgg16_method1.pth")  # 加载模型结构及参数
print("方式1:\n", model1)  # 打印模型网络结构

# 方式2:加载模型
model_data = torch.load("vgg16_method2.pth")  # 加载模型参数
print("方式2:\n", model_data)  # 打印模型网络参数
vgg16 = torchvision.models.vgg16(pretrained=False)  # vgg16网络模型
vgg16.load_state_dict(model_data)  # 将模型参数加载到模型里

# 陷阱1:导入模型时报错缺少类定义(AttributeError)
# 解决方法:在当前文件加载import该类  from Model import Model   Model.py文件里定义了Model类
model = torch.load("model_method.pth")  # 加载模型结构及参数
print("陷阱1:\n", model)

输出

方式1:
 VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

方式2:
 OrderedDict([('features.0.weight', tensor([[[[ 3.7737e-04,  4.1346e-02,  6.0702e-02],
          [ 7.0125e-02,  3.7126e-02, -7.6289e-02],
          [ 1.2145e-01,  4.2173e-02, -1.1606e-01]],

         [[-2.3715e-02,  1.9658e-02, -7.4128e-02],
          [-2.9713e-02,  3.6599e-03,  9.9301e-03],
          [-4.9300e-02,  5.1934e-02,  1.0522e-01]],

         [[ 1.4076e-02,  5.1264e-02, -5.4800e-02],
          [-3.5250e-02,  2.0560e-02, -2.7887e-03],
          [ 2.2512e-02,  5.9779e-02,  4.9314e-02]]],


        [[[-8.1202e-03, -4.0062e-02, -4.1275e-02],
          [ 1.3463e-02, -4.1142e-02,  1.1663e-01],
          [-1.6806e-02,  7.7193e-02,  5.9772e-02]],

         [[-3.7491e-03,  7.0595e-02,  3.9575e-02],
          [-1.7332e-01,  5.7054e-02,  1.2022e-01],
          [ 1.6720e-02, -1.2557e-02,  8.1462e-02]],

         [[ 2.0320e-02, -9.4389e-03, -2.6056e-02],
          [-9.8172e-03,  1.4638e-01, -2.9588e-04],
          [ 1.9194e-02, -5.7499e-02,  4.5579e-02]]],


        [[[ 8.1152e-02, -3.3212e-02,  4.4831e-02],
          [-2.5436e-02, -3.9699e-02, -4.9673e-02],
          [-2.0726e-02,  1.9308e-02,  1.5040e-02]],

         [[ 1.0469e-01,  2.3499e-02,  2.0060e-02],
          [-9.3836e-02, -3.8625e-02, -4.0413e-02],
          [ 7.2539e-02,  2.8679e-02,  3.7398e-02]],

         [[-1.9462e-03, -9.2730e-02,  2.1433e-03],
          [-1.2013e-01,  6.4750e-02,  8.3451e-02],
          [-8.4348e-02,  5.1198e-02, -1.5884e-01]]],


        ...,
        ...,
        ...,
        [-0.0068,  0.0025,  0.0026,  ..., -0.0150, -0.0085, -0.0084],
        [ 0.0023, -0.0015, -0.0213,  ...,  0.0131, -0.0111, -0.0071],
        [ 0.0091, -0.0014, -0.0073,  ..., -0.0146,  0.0060,  0.0087]])), ('classifier.0.bias', tensor([0., 0., 0.,  ..., 0., 0., 0.])), ('classifier.3.weight', tensor([[-0.0036,  0.0033,  0.0061,  ...,  0.0100,  0.0028, -0.0114],
        [-0.0017, -0.0052,  0.0002,  ...,  0.0097,  0.0015,  0.0009],
        [ 0.0189, -0.0090,  0.0017,  ..., -0.0046,  0.0094, -0.0055],
        ...,
        [-0.0081, -0.0144,  0.0065,  ...,  0.0009, -0.0081, -0.0141],
        [ 0.0085,  0.0051,  0.0092,  ...,  0.0080, -0.0117,  0.0045],
        [-0.0038, -0.0033,  0.0118,  ..., -0.0112, -0.0121, -0.0056]])), ('classifier.3.bias', tensor([0., 0., 0.,  ..., 0., 0., 0.])), ('classifier.6.weight', tensor([[-0.0070,  0.0144,  0.0028,  ...,  0.0072,  0.0221,  0.0056],
        [ 0.0203, -0.0066,  0.0003,  ...,  0.0057, -0.0002,  0.0077],
        [-0.0004,  0.0128,  0.0234,  ...,  0.0073,  0.0079,  0.0003],
        ...,
        [-0.0023,  0.0004, -0.0097,  ...,  0.0037, -0.0093,  0.0014],
        [-0.0048, -0.0078, -0.0077,  ...,  0.0131, -0.0044,  0.0071],
        [-0.0050, -0.0099, -0.0006,  ..., -0.0062, -0.0243, -0.0062]])), ('classifier.6.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))])

陷阱1:
 Model(
  (model1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)

你可能感兴趣的:(Pytorch,pytorch,学习,网络)