%matplotlib inline
#将所有的变量直接显示,而不用显式的调用print
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
英语渣,pytorch不熟悉,翻了个人记录用,侵了请联系我删。。。。
CSDN看得麻烦的,可以移步
https://colab.research.google.com/drive/17WjYQAK8ynPIhIw3E5iHd7Z3BibLJMjf
作者: 马修Inkawhich
译者: 团长sama
本文档提供解决方案,以各种关于PyTorch模型的保存和加载使用情况。随时阅读整个文档,或者只是跳到你需要一个期望的使用情况下的代码。
保存和加载模型需要了解以下三个核心功能:
torch.save:将序列化对象保存到磁盘。此功能使用Python的pickle程序进行序列化。模型,张量,以及各类对象的字典可以使用该功能进行保存。
[torch.load](https://pytorch.org/docs/stable/torch.html?highlight=torch load#torch.load):使用pickle将pickle文件对象反序列化到内存。该方法也方便设备加载数据(见 Saving & Loading Model Across Devices )。
torch.nn.Module.load_state_dict :使用反序列化的 state_dict 加载模型的参数字典。有关state_dict 参见 What is a state_dict?.
内容:
state_dict
?).torch.nn.DataParallel
模型)state_dict
?在PyTorch中,torch.nn.Module
模型中可学习的参数(即重量和偏置)都包含在模型参数中(使用model.parameters()
访问)。state_dict 是个将每层参数映射到对应的参数张量的python字典对象(OrderedDict)。**注意:state_dict的条目仅包括带有可学习参数的层(卷积层,线性层等)和registered buffers(BN层的mean等)。优化器对象(torch.optim
)也有state_dict ,它包含有关该优化器状态信息,以及所使用的超参数。
由于state_dict对象是OrderedDict,它们可以方便地保存,更新,修改和恢复,方便PyTorch模型和优化器添加了大量的模块化。
让我们来看看分类器训练教程中模型的state_dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140064166433992, 140064166434136, 140064166434208, 140064166434280, 140064166434352, 140064166434424, 140064166434496, 140064166434568, 140064166434640, 140064166434712]}]
state_dict
(推荐)!pwd
/content
PATH="/content/test.pth"
保存:
torch.save(model.state_dict(), PATH)
载入:
model = TheModelClass()
print(torch.load(PATH).keys())
model.load_state_dict(torch.load(PATH))
model.eval()
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
TheModelClass(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
当保存模型用于推理时,仅需要保存经过训练的模型参数就可以了.使用torch.save()
方法保存模型的state_dict
具有最高的灵活性,方便以后恢复模型,所以这里推荐使用torch.save()
方法保存模型参数.
通常我们约定使用.pt
或者.pth
作为模型参数文件的后缀名.
记住,在进行推理之前,务必使用model.eval()
方法将dropout和BN层设置为推理模式,否则将得到和训练不一致的推理结果.
NOTE
注意load_state_dict()
函数的形参是一个字典对象,而不是权重文件的地址.这意味你必须在使用load_state_dict()
将保存的state_dict
反序列化传进去,而不能直接传权重文件的路径
保存:
torch.save(model, PATH)
/usr/local/lib/python3.6/dist-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type TheModelClass. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
/usr/local/lib/python3.6/dist-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Conv2d. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
/usr/local/lib/python3.6/dist-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type MaxPool2d. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
/usr/local/lib/python3.6/dist-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
载入:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
TheModelClass(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
这种保存/加载模型的方式使用最直观的语法和最少的代码.保存模型使用的是python的pickle模块.该方法的缺点就是在使用确切的目录结构保存模型,且序列化数据是绑定到特定的类上.这样做的原因时pickle不会保存模型类本身.相反,它保存了包含类文件的路径,以方便在加载时使用.因此,当其他项目使用和重构时,代码可能以各种方式中断.
通常我们约定使用.pt
或者.pth
作为模型参数文件的后缀名.
记住,在进行推理之前,务必使用model.eval()
方法将dropout和BN层设置为推理模式,否则将得到和训练不一致的推理结果.
保存:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载:
model = TheModelClass()
optimizer = TheOptimizerClass()
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
当保存通用检查点时,无论是用于推理还是恢复训练,需要保存更多的信息而不仅仅是模型的state_dict
.保存优化器的state_dict
也很重要,因为里面保存着随着模型训练而更新的缓冲区和参数.其他的项目,例如中断训练时的epoch,最近的训练loss,外部的torch.nn.Embedding
层等也需要保存.
一般我们将这些信息组织成字典,使用torch.save()
将字典序列化之后进行保存.
通常我们约定使用.tar
作为文件扩展名来保存这些检查点.
加载这些项目时,首先要实例化模型和优化器,然后使用torch.load()
方法来载入字典.
记住,在进行推理之前,务必使用model.eval()
方法将dropout和BN层设置为推理模式,否则将得到和训练不一致的推理结果.若想接着训练,调用model.train()
来确保这些层处于训练模式.
保存:
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
载入:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
当保存诸如GAN,Seq2Seq,集成模型等由多个torch.nn.Modules
组成的模型时,可以使用和上一章节相同的方法.
保存:
torch.save(modelA.state_dict(), PATH)
载入:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
部分加载模型或者加载部分模型时迁移学习和训练新的复杂模型时的常见手段.利用已训练过的参数,即便只有一些可用,也有助于热启动模型的训练过程,并帮助模型比从头开始训练收敛得更快.
无论你是从缺少一些键的字典还是多出一些键的字典加载模型state_dict
,你可以通过设置load_state_dict()
方法中的strict
参数为False
来忽略不匹配的键.
若是你想将某层的参数加载到另一层,但是一些键不匹配,只需更改要加载的state_dict
中参数键的名称,以匹配到你要加载到的模型键中就行了.
保存:
torch.save(model.state_dict(), PATH)
加载:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
当加载在GPU上训练和保存的模型权重到GPU上时,实例化模型之后调用model.to(torch.device('cuda'))
就行,注意同时要保证.to(torch.device('cuda'))
方法应用于所有的模型输入上.my_tensor.to(device)
将返回一个新的my_tensor
拷贝到GPU,而不是覆写my_tensor
.因此,需要手动覆写张量my_tensor = my_tensor.to(torch.device('cuda'))
保存:
torch.save(model.state_dict(), PATH)
载入:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
当加载一个训练和保存都在CPU上的模型权重到GPU上时,将torch.load()
函数中map_location
参数设置成cuda:device_id
,可以将模型加载到一个指定的GPU上.注意同时要保证.to(torch.device('cuda'))
方法应用于所有的模型输入上.my_tensor.to(device)
将返回一个新的my_tensor
拷贝到GPU,而不是覆写my_tensor
.因此,需要手动覆写张量my_tensor = my_tensor.to(torch.device('cuda'))
torch.nn.DataParallel
模型保存:
torch.save(model.module.state_dict(), PATH)
载入:
# 跟上一章一样
torch.nn.DataParallel
是个支持GPU并行的模型包装器.使用model.module.state_dict()
函数可以以常规方式保存DataParallel
模型. 这样你可以灵活的以任何方式将模型加载到任何设备上.