PyTorch提供了三种方式来保存和加载模型,在这三种方式中,加载模型的代码和保存模型的代码必须相匹配,才能保证模型的加载成功。通常情况下,使用第一种方式(保存和加载模型状态字典)更加常见,因为它更轻量且不依赖于特定的模型类。
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
# 保存整个模型
torch.save(model.state_dict(), 'sample_model.pt')
import torch
import torch.nn as nn
# 下载模型参数 并放到模型中
loaded_model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
loaded_model.load_state_dict(torch.load('sample_model.pt'))
print(loaded_model)
显示如下:
Sequential(
(0): Linear(in_features=128, out_features=16, bias=True)
(1): ReLU()
(2): Linear(in_features=16, out_features=1, bias=True)
)
net.state_dict(),在PyTorch中,Module 的可学习参数 (即权重和偏差),模块模型包含在参数中 (通过 model.parameters() 访问)。state_dict 是一个从参数名称隐射到参数 Tesnor 的有序字典对象。只有具有可学习参数的层(卷积层、线性层等) 才有 state_dict 中的条目。
import torch
import torch.nn as nn
net = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
# 保存整个模型,包含模型结构和参数
torch.save(net, 'sample_model.pt')
import torch
import torch.nn as nn
# 加载整个模型,包含模型结构和参数
loaded_model = torch.load('sample_model.pt')
print(loaded_model)
显示如下:
Sequential(
(0): Linear(in_features=128, out_features=16, bias=True)
(1): ReLU()
(2): Linear(in_features=16, out_features=1, bias=True)
)
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
input_sample = torch.randn(16, 128) # 提供一个输入样本作为示例
torch.onnx.export(model, input_sample, 'sample_model.onnx')
import torch
import torch.nn as nn
import onnx
import onnxruntime
loaded_model = onnx.load('sample_model.onnx')
session = onnxruntime.InferenceSession('sample_model.onnx')
print(session)
保存模型函数torch.save
将对象序列化保存到磁盘中,该方法原理是基于python中的pickle
来序列化,各种Models
,tensors
,dictionaries
都可以使用该方法保存。保存的模型文件名可以是.pth
, .pt
, .pkl
。
def save(
obj: object,
f: FILE_LIKE,
pickle_module: Any = pickle,
pickle_protocol: int = DEFAULT_PROTOCOL,
_use_new_zipfile_serialization: bool = True
) -> None:
备注:关于模型的后缀.pt、.pth、.pkl它们并不存在格式上的区别,只是后缀名不同而已。 torch.save()语句保存出来的模型文件没有什么不同。
加载模型函数torch.load
def load(
f: FILE_LIKE,
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,
weights_only: bool = False,
**pickle_load_args: Any
) -> Any:
torch.device
对象加载模型参数torch.nn.Module.load_state_dict
序列化 (Serialization)是将对象的状态信息转换为可以存储或传输的形式的过程。 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象。
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True):
函数作用是“获取优化器当前状态信息字典”,在神经网络中模型上训练出来的模型参数,也就是权重和偏置值。在Pytorch中,定义网络模型是通过继承torch.nn.Module来实现的。其网络模型中包含可学习的参数(weights, bias, 和一些登记的缓存如batchnorm’s running_mean 等)。模型内部的可学习参数可通过两种方式进行调用:
def state_dict(self, destination=None, prefix='', keep_vars=False):
除模型外,优化器对象(torch.optim)同样也有一个状态字典,包含的优化器状态信息以及使用的超参数。由于状态字典属于Python 字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都比较便捷。
采用仅加载模型参数的方式,指定设备类型进行模型加载,代码如下:
model_path = '/opt/sample_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
map_location = torch.device(device)
model.load_state_dict(torch.load(self.model_path, map_location=self.map_location))