本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数:
torch.save
:把序列化的对象保存到硬盘。它利用了 Python 的 pickle
torch.load
:采用 pickle
将反序列化的对象从存储中加载进来。torch.nn.Module.load_state_dict
:采用一个反序列化的 state_dict
加载一个模型的参数字典。PyTorch 中,一个模型(torch.nn.Module
)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters()
)中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm
的 running_mean
)。优化器对象(torch.optim
)同样也是有一个状态字典,包含的优化器状态信息以及使用的超参数。
由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。
下面是一个简单的使用例子
# 定义模型
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 初始化模型
model = TheModelClass()
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 打印模型的state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 打印优化器的state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
上述代码先是简单定义一个 5 层的 CNN,然后分别打印模型的参数和优化器参数。
输出结果:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
保存的代码:
torch.save(model.state_dict(), PATH) # PATH = 'models/net.pth'
则同级目录models下将会出现net.pth文件,pth文件中的内容就是model的参数名称和值对应的state_dict
加载的代码:
# 需先搭建网络模型model
model = TheModelClass(*args, **kwargs)
# 然后通过下面的语句加载参数
model.load_state_dict(torch.load(PATH)) # PATH = 'models/net.pth'
model.eval()
当需要为inference保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save()
来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。
通常会用 .pt
或者 .pth
后缀来保存模型。
记住
torch.save()
,使用了一种新的基于 zipfile
的文件格式。Load仍然保留以旧格式加载文件的能力。如果希望 torch.save()
使用旧格式,请传递 kwarg_use_new_zipfile_serialization = False
。model.eval()
方法来将 dropout
和 batch normalization
层设置为验证模型,防止参数更新。load_state_dict()
方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用torch.load()
,而不是直接model.load_state_dict(PATH)
保存:
torch.save(model, PATH) #'./model.pth'
加载:
# 模型类必须在某处定义
# 加载完整的模型结构和参数信息,在网络较大时加载时间比较长,同时存储空间也比较大
model = torch.load(PATH)
model.eval()
保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle
模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle
并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors
后采用都可能出现错误。
示例
来自莫烦Python
import torch
import matplotlib.pyplot as plt
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.3)
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# 2 ways to save the net
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
# save net1
save()
# restore entire net (may slow)
restore_net()
# restore only the net parameters
restore_params()
保存优化器参数值和epoch值的主要目的是用于继续训练,保存的流程依旧是先torch.save()
再torch.load_state_dict()
,我们首先定义一个Adam优化器、一个任意的epoch值与net如下:
net = Net()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
epoch = 96
现在,创建一个字典来保存所有的对象,并用save函数保存这个字典:
all_states = {"net": net.state_dict(), "Adam": Adam.state_dict(), "epoch": epoch}
torch.save(obj=all_states, f="models/all_states.pth")
所有的对象all_states.pth都被保存到models文件夹下了
可以使用load()函数把所有的对象再次提取出来
reload_states = torch.load("models/all_states.pth")
print(reload_states)
保存的示例代码:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载的示例代码:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅要存储模型的参数 model.state_dict
,还有优化器的 optimizer.state_dict
,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch
,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding
层等等。
上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save
方法序列化字典,一般保存的文件后缀名是 .tar
。
加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load
加载对应的 state_dict
。然后通过不同的键来获取对应的数值。
加载完后,根据后续步骤,调用 model.eval()
用于预测,model.train()
用于恢复训练。
保存模型的示例代码:
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
加载模型的示例代码:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
当我们希望保存的是一个包含多个网络模型 torch.nn.Modules
的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 state_dict
和对应优化器的 state_dict
。除此之外,还可以继续保存其他相同的信息。
加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar
作为后缀名。
保存模型的示例代码:
torch.save(modelA.state_dict(), PATH)
加载模型的示例代码:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
在之前迁移学习教程中也介绍了可以通过预训练模型来微调,加快模型训练速度和提高模型的精度。
这种做法通常是加载预训练模型的部分网络参数作为模型的初始化参数,然后可以加快模型的收敛速度。
加载预训练模型的代码如上述所示,其中设置参数 strict=False
表示忽略不匹配的网络层参数,因为通常我们都不会完全采用和预训练模型完全一样的网络,通常输出层的参数就会不一样。
当然,如果希望加载参数名不一样的参数,可以通过修改加载的模型对应的参数名字,这样参数名字匹配了就可以成功加载。
在GPU上保存模型,在 CPU 上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load()
的时候,设置参数 map_location
,指定采用的设备是 torch.device('cpu')
,这个做法会将张量都重新映射到 CPU 上。
在GPU上保存模型,在 GPU 上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
在 GPU 上训练和加载模型,调用 torch.load()
加载模型后,还需要采用 model.to(torch.device('cuda'))
,将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device)
。
在CPU上保存,在GPU上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
这次是 CPU 上训练模型,但在 GPU 上加载模型使用,那么就需要通过参数 map_location
指定设备。然后继续记得调用 model.to(torch.device('cuda'))
。
保存 torch.nn.DataParallel 模型
保存模型的示例代码:
torch.save(model.module.state_dict(), PATH)
torch.nn.DataParallel
是用于实现多 GPU 并行的操作,保存模型的时候,是采用 model.module.state_dict()
。
加载模型的代码也是一样的,采用 torch.load()
,并可以放到指定的 GPU 显卡上。
# PyTorch中的torchvision里有很多常用的模型,可以直接调用:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
这里是直接调用pytorch中的常用模型
# 导入模型结构
resnet18 = models.resnet18(pretrained=False)
# 加载预先下载好的预训练参数到resnet18
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))
resnet152 = models.resnet152(pretrained=True)
pretrained_dict = resnet152.state_dict()
"""
加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
也可以直接从官方model_zoo下载:
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
"""
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
pytorch中state_dict()和load_state_dict()函数配合使用可以实现状态的获取与重载,load()和save()函数配合使用可以实现参数的存储与读取。其中最重要的部分是“字典”的概念,因为参数对象的存储是需要“名称”——“值”对应(即键值对),读取时也是通过键值对读取的。
参考
https://zhuanlan.zhihu.com/p/82038049
https://zhuanlan.zhihu.com/p/94971100