Pytorch提供了两种方式进行保存模型。
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth") # 保存模型结构及参数
# 保存方式2:模型参数,保存成字典的形式(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 陷阱1:方式1保存模型,陷阱在加载处
class Model(nn.Module):
def __init__(self) -> None:
super().__init__() # 初始化父类属性
self.model1 = Sequential(
Conv2d(3, 32, 5, stride=1, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
)
def forward(self, x):
x = self.model1(x)
return x
model = Model()
torch.save(model, "model_method.pth") # 保存模型结构及参数
Pytorch提供了两种方式进行读取模型。
注意:读取模型时,必须引入该模型结构的class定义,否则加载模型时报错缺少类定义。
import torch
import torchvision.models
from Model import Model # 引入模型类,防止加载自定义模型报错
# 方式1:加载模型
model1 = torch.load("vgg16_method1.pth") # 加载模型结构及参数
print("方式1:\n", model1) # 打印模型网络结构
# 方式2:加载模型
model_data = torch.load("vgg16_method2.pth") # 加载模型参数
print("方式2:\n", model_data) # 打印模型网络参数
vgg16 = torchvision.models.vgg16(pretrained=False) # vgg16网络模型
vgg16.load_state_dict(model_data) # 将模型参数加载到模型里
# 陷阱1:导入模型时报错缺少类定义(AttributeError)
# 解决方法:在当前文件加载import该类 from Model import Model Model.py文件里定义了Model类
model = torch.load("model_method.pth") # 加载模型结构及参数
print("陷阱1:\n", model)
输出:
方式1:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
方式2:
OrderedDict([('features.0.weight', tensor([[[[ 3.7737e-04, 4.1346e-02, 6.0702e-02],
[ 7.0125e-02, 3.7126e-02, -7.6289e-02],
[ 1.2145e-01, 4.2173e-02, -1.1606e-01]],
[[-2.3715e-02, 1.9658e-02, -7.4128e-02],
[-2.9713e-02, 3.6599e-03, 9.9301e-03],
[-4.9300e-02, 5.1934e-02, 1.0522e-01]],
[[ 1.4076e-02, 5.1264e-02, -5.4800e-02],
[-3.5250e-02, 2.0560e-02, -2.7887e-03],
[ 2.2512e-02, 5.9779e-02, 4.9314e-02]]],
[[[-8.1202e-03, -4.0062e-02, -4.1275e-02],
[ 1.3463e-02, -4.1142e-02, 1.1663e-01],
[-1.6806e-02, 7.7193e-02, 5.9772e-02]],
[[-3.7491e-03, 7.0595e-02, 3.9575e-02],
[-1.7332e-01, 5.7054e-02, 1.2022e-01],
[ 1.6720e-02, -1.2557e-02, 8.1462e-02]],
[[ 2.0320e-02, -9.4389e-03, -2.6056e-02],
[-9.8172e-03, 1.4638e-01, -2.9588e-04],
[ 1.9194e-02, -5.7499e-02, 4.5579e-02]]],
[[[ 8.1152e-02, -3.3212e-02, 4.4831e-02],
[-2.5436e-02, -3.9699e-02, -4.9673e-02],
[-2.0726e-02, 1.9308e-02, 1.5040e-02]],
[[ 1.0469e-01, 2.3499e-02, 2.0060e-02],
[-9.3836e-02, -3.8625e-02, -4.0413e-02],
[ 7.2539e-02, 2.8679e-02, 3.7398e-02]],
[[-1.9462e-03, -9.2730e-02, 2.1433e-03],
[-1.2013e-01, 6.4750e-02, 8.3451e-02],
[-8.4348e-02, 5.1198e-02, -1.5884e-01]]],
...,
...,
...,
[-0.0068, 0.0025, 0.0026, ..., -0.0150, -0.0085, -0.0084],
[ 0.0023, -0.0015, -0.0213, ..., 0.0131, -0.0111, -0.0071],
[ 0.0091, -0.0014, -0.0073, ..., -0.0146, 0.0060, 0.0087]])), ('classifier.0.bias', tensor([0., 0., 0., ..., 0., 0., 0.])), ('classifier.3.weight', tensor([[-0.0036, 0.0033, 0.0061, ..., 0.0100, 0.0028, -0.0114],
[-0.0017, -0.0052, 0.0002, ..., 0.0097, 0.0015, 0.0009],
[ 0.0189, -0.0090, 0.0017, ..., -0.0046, 0.0094, -0.0055],
...,
[-0.0081, -0.0144, 0.0065, ..., 0.0009, -0.0081, -0.0141],
[ 0.0085, 0.0051, 0.0092, ..., 0.0080, -0.0117, 0.0045],
[-0.0038, -0.0033, 0.0118, ..., -0.0112, -0.0121, -0.0056]])), ('classifier.3.bias', tensor([0., 0., 0., ..., 0., 0., 0.])), ('classifier.6.weight', tensor([[-0.0070, 0.0144, 0.0028, ..., 0.0072, 0.0221, 0.0056],
[ 0.0203, -0.0066, 0.0003, ..., 0.0057, -0.0002, 0.0077],
[-0.0004, 0.0128, 0.0234, ..., 0.0073, 0.0079, 0.0003],
...,
[-0.0023, 0.0004, -0.0097, ..., 0.0037, -0.0093, 0.0014],
[-0.0048, -0.0078, -0.0077, ..., 0.0131, -0.0044, 0.0071],
[-0.0050, -0.0099, -0.0006, ..., -0.0062, -0.0243, -0.0062]])), ('classifier.6.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))])
陷阱1:
Model(
(model1): Sequential(
(0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
)