pytorch模型保存与加载

官方文档

模型保存相关的三个核心功能

torch.save: 将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化,使用此模型可以保存如模型、tensor、字典等各种对象。
torch.load: 使用pickle的unpicking功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
torch.nn.Moudle.load_state_dict: 使用反序列化函数state_dict来加载模型的参数字典。

状态字典

  在pytorch中,torch.nn.Module模型的可学习参数(即权重和偏差)包含在模型的parameters中,(使用model.parameters()可以进行访问)。state_dict仅仅是python字典对象,它将每一层映射到其参数张量。注意,只有具有可学习参数的层(如卷积层、线性层等)的模型才具有state_dict这一项。优化目标torch.optim也有state_dict属性,它包含有关优化器的状态信息,以及使用的超参数。

示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])

Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

保存和加载推断模型

保存/加载state_dict(推荐使用)

保存:

1
torch.save(model.state_dict(), PATH)

加载:

1
2
3
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

  用保存的模型进行推断的时候,只需要保存模型学习到的参数,使用torch.save()函数来保存模型state_dict,所用的资源要少于保存完整模型。在进行推断之前,要调用model.eval()去设置dropout和batch normalization层为评估模式。在传入load_state_dict()函数之前,需要使用torch.load()state_dict进行反序列化。

保存/加载完整模型

保存:

1
torch.save(model, PATH)

加载:

1
2
model = torch.load(PATH)
model.eval()

保存torch.nn.DataParallel模型

保存:

1
2
3
model = TheModelClass(*args, **kwargs)
model = torch.nn.DataParallel(model)
torch.save(model.state_dict(), PATH)

加载:

1
2
3
4
model = TheModelClass(*args, **kwargs)
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(PATH))
model.eval()

在加载模型继续训练的时候,加载了两次torch.nn.DataParallel,保存的模型进行推断也需要加载两次才能进行推断。可以通过以下方法将保存的模型转化为非DataParallel模式的模型(所有key的名字前去掉modules)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from collections import OrderedDict
from efficientnet import efficientnet_b0b
import torch.nn as nn
import torch

model_path = 'Result/efficientnet/07-22_13-15-51/1net_params.pkl'
state_dict = torch.load(model_path)
new_state_dict = OrderedDict()

for k, v in state_dict.items():
name = k[7:]
new_state_dict[name] = v

two_state_dict = OrderedDict()

for k, v in new_state_dict.items():
name = k[7:]
two_state_dict[name] = v

net = efficientnet_b0b((224, 224), num_classes=1852)
net = nn.DataParallel(net)
net.load_state_dict(new_state_dict)

你可能感兴趣的:(pytorch模型保存与加载)