Pytorch:保存模型与加载模型

本文目录

  • 1 保存模型
    • torch.save()
  • 2 加载模型
    • 2.1 torch.load()
    • 2.2 torch.nn.Module.load_state_dict()
  • 参考文章

1 保存模型

torch.save()

torch.save()将序列化对象保存到磁盘。序列化就是把数据变成可存储或可传输的过程的,只有序列化后的数据才可以写入到磁盘或者通过网。此函数使用Python的 pickle进行序列化。使用此函数可以保存各种对象的模型、张量和字典。

pickle的一些具体使用函数

pickle.dump() #将任意对象转化成bytes,并写入文件中。

import pickle

# 使用pickle模块将数据对象保存到文件
data = {'a': [1, 2.0, 3, 4+6j],
         'b': ('string', u'Unicode string'),
         'c': None}

file = open('data.pkl','wb')	#将任意对象转化成bytes,并写入文件中
pickle.dump(data, file)
file.close()

2 加载模型

2.1 torch.load()

torch.load()使用Python的pickle将被序列化的对象文件反序列化到内存。此函数还可方便设备将数据加载进来(请看 Saving & Loading Model Across Devices).

pickle的一些具体使用函数

pickle.load() #从文件中反序列出对象

import pickle

file = open('data.pkl','rb')
data = pickle.load(file)
file.close()
print(data)

#{'a': [1, 2.0, 3, (4+6j)], 'b': ('string', 'Unicode string'), 'c': None}

2.2 torch.nn.Module.load_state_dict()

在PyTorch中,

  • 模型参数:
    可以使用model.parameters()访问torch.nn.Module模型的可学习参数(即权重和偏置)包含在模型的 parameters 中。torch.nn.Modulestate_dict 是一个Python字典对象,它将具有可学习参数的层(卷积层、线性层等)映射到其参数张量。
  • 优化器参数:
    Optimizer对象torch.optim有一个state_dict,包含关于优化器状态的信息以及使用的超参数。
    因为 state_dict 对象是Python字典,所以可以轻松地保存、更新、修改和恢复它们, 从而为PyTorch模型和优化器添加了大量的模块化。
    torch.nn.Module.load_state_dict()使用反序列化的state_dict加载模型的参数字典。
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化 model
model = TheModelClass()

# 初始化 optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 输出 model 的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Model's state_dict:
# conv1.weight 	 torch.Size([6, 3, 5, 5])
# conv1.bias 	 torch.Size([6])
# conv2.weight 	 torch.Size([16, 6, 5, 5])
# conv2.bias 	 torch.Size([16])
# fc1.weight 	 torch.Size([120, 400])
# fc1.bias 	 torch.Size([120])
# fc2.weight 	 torch.Size([84, 120])
# fc2.bias 	 torch.Size([84])
# fc3.weight 	 torch.Size([10, 84])
# fc3.bias 	 torch.Size([10])

# 输出 optimizer 的 state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
#Optimizer's state_dict:
# state 	 {}
# param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

在保存模型进行推理时,只需保存经过训练的模型的学习参数state_dict即可。使用 torch.save()函数保存模型的state_dict将为以后恢复模型提供最大的灵活性,这就是为什么推荐使用它来保存模型。
常见的PyTorch约定是使用 .pt .pth文件扩展名保存模型。

模型参数的保存

torch.save(model.state_dict(), file_path)

模型参数的加载

#device = torch.device("cuda")
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(file_path))
#model.to(device)

参考文章

PyTorch save and load models 保存和加载模型

你可能感兴趣的:(pytorch,python,pytorch,python,深度学习)