PyTorch模型的保存加载以及数据的可视化

文章目录

  • PyTorch模型的保存和加载
    • 模块和张量的序列化和反序列化
    • 模块状态字典的保存和载入
  • PyTorch数据的可视化
    • TensorBoard的使用
  • 总结

PyTorch模型的保存和加载

在深度学习模型的训练过程中,如何周期性地对模型做存档非常重要。

一方面,深度学习模型的训练是一个长期的过程,一般来说,大的模型可能运行数天或者数周,这样可能就会在训练的过程中出现一些问题。由于模型一般在运行时保存在计算机的内存或者显存中,一旦出现问题可能会导致模型训练结果的丢失。另一方面,对于训练好的模型,经常要对实际的数据进行预测,这就要求训练好的模型权重能以一定的格式保存到硬盘中,方便后续使用时直接载入原来的权重。

基于这两点的共同要求,PyTorch提供了很好的机制来进行模型的保存和加载。

模块和张量的序列化和反序列化

由于PyTorch的模块和张量本质上是torch.nn.Module和torch.tensor类的实例,而PyTorch自带了一系列方法可将这些类的实例转换成字符串,所以这些实例可以通过Python序列化方法进行序列化和反序列化。

PyTorch里面集成了Python自带的pickle包对模块和张量进行序列化。张量的序列化过程本质上是把张量的信息,包括数据类型和存储位置以及携带的数据等转换为字符串,而这些字符串随后可以使用Python自带的文件IO函数进行存储。同样也可以通过文件IO函数读取存储的字符串然后将字符串逆向解析成PyTorch的模块和张量。

保存和载入的函数签名如下:

torch.save(obj, f, pickle_modeule=pickle, pickle_protocol=2)
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

torch.save函数传入的第一个参数是PyTorch中可以被序列化的对象,包括模型和张量等。第二个参数是存储文件的路径,序列化的结果将会被保留在这个路径里。第三个参数是默认的,传入的是序列化的库,第四个参数是pickle协议,即如何把对象转换成字符串的规范,其协议版本有0~4版本。

troch.load函数在给定序列化后的文件路径以后,就能输出PyTorch的对象。其第一个参数是文件路径,第二个参数是张量存储位置的映射,第三个参数和torch.save中的一致,最后一个参数用来指定传给pickle_module.load的参数。

模块状态字典的保存和载入

在PyTorch中一般模型可以由两种保存方式,第一种是直接保存模型的实例,第二种是保存模型的状态字典,一个模型的状态字典包含模型所有参数的名字以及名字对应的张量。通过调用state_lict方法,可以获取当前模型的状态字典。

如下代码所示:

lm = LinearModel(5) # 定义线性模型
lm.state_dict() # 获取状态字典
print(lm.state_dict())
t = lm.state_dict # 保存状态字典


lm = LinearModel(5) # 重新定义线性模型
lm.state_dict() # 新的状态字典
print(lm.state_dict())

lm.load_state_dict(t)# 载入原来的状态字典
print(lm.state_dict())

得到的结果如下所示:

OrderedDict([('weight', tensor([[ 0.5843],
        [-1.0206],
        [-0.7556],
        [ 1.3406],
        [ 0.1169]])), ('bias', tensor([-0.2059]))])
OrderedDict([('weight', tensor([[-0.0865],
        [ 0.1113],
        [ 0.9502],
        [-2.3019],
        [ 0.2588]])), ('bias', tensor([-0.0324]))])
OrderedDict([('weight', tensor([[ 0.5843],
        [-1.0206],
        [-0.7556],
        [ 1.3406],
        [ 0.1169]])), ('bias', tensor([-0.2059]))])

可以看到线性回归模型实例反悔了OrderedDict的对象,即顺序字典,其中有两个键值对,分别对应着权重和偏置的张量。获取新的状态字典和原来的不同,当通过load_state_dict方法传入旧的状态字典让模型载入参数后,发现模型的参数更新为原来的模型参数。

一般来说推荐使用state_dict方法获取状态字典,然后保存该张量字典来保存模型,这样可以最大限度地减小代码对PyTorch版本的依赖性。

PyTorch数据的可视化

TensorBoard的使用

TensorBoard是一个数据可视化工具,能够直观地显示深度学习过程中张量的变化,从这个变化中就可以很容易地了解到模型在训练中的行为,包括但不限于损失函数的下降趋势是否合理、张量分量的分布是否在训练中发生变化以及输出训练过程中的图片等等。

这里还是使用博士顿地区房价数据的线性回归模型来举例:

from sklearn.datasets import load_boston
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn

class LinearModel(nn.Module):
    def __init__(self, ndim):
        super(LinearModel, self).__init__()
        self.ndim = ndim
        self.weight = nn.Parameter(torch.randn(ndim, 1))
        self.bias = nn.Parameter(torch.randn(1))
        
    def forward(self, x):
        
        return x.mm(self.weight) + self.bias

boston = load_boston()
lm = LinearModel(13)
criterion = nn.MSELoss()
optim = torch.optim.SGD(lm.parameters(), lr = 1e-6) 
data = torch.tensor(boston["data"], requires_grad=True, dtype=torch.float32) 
target = torch.tensor(boston["target"], dtype=torch.float32)
writer = SummaryWriter() # 定义TensorBoard输出类

for step in range(10000):
    predict = lm(data)
    loss = criterion(predict, target)
    writer.add_scalar("Loss/train", loss, step) # 输出损失函数
    writer.add_histogram("Param/weight", lm.weight, step) # 输出权重直方图
    writer.add_histogram("Param/bias", lm.bias, step) # 输出偏置直方图

    if step and step % 1000 == 0:
        print("Loss:{:.3f}".format(loss.item()))

    optim.zero_grad()
    loss.backward()
    optim.step()

这里相比于之前增加了SummaryWriter的构造函数,在构造一个摘要写入器的实例以后,可以调用实例的方法来添加需要写入摘要的张量信息。这里主要写入了一个标量数据和两个直方图数据。

通过运行训练的代码,在运行10000个epoch之后,可以发现在当前目录下多了以文件夹runs,runs下面有一个文件夹,具体的文件夹名字与训练开始时间、用户主机名称有关。

接下来可以运行tensorboard --logdir ./runs命令,发现Tensorboard的服务器已经启动,可以通过浏览器访问http://127.0.0.1:6006,查看Tensorboard网页界面,如图所示:

PyTorch模型的保存加载以及数据的可视化_第1张图片
从图中可以看出TensorBoard的界面可以显示很多值,比如SCALARS、DISTRIBUTIONS、HISTOGRAMS等。在默认情况下,TensorBoard只显示写入数据类型的几个标签,这里主要是add_scalar产生的SCALARS和add_histogram产生的DISTRIBUTIONS和HISTOGRAMS标签。

PyTorch模型的保存加载以及数据的可视化_第2张图片
SCLARS图像主要是损失函数随着训练步数变化的曲线。

PyTorch模型的保存加载以及数据的可视化_第3张图片
DISTRIBUTION主要显示权重值和偏置的最大和最小的边界随着训练步数的变化过程。

PyTorch模型的保存加载以及数据的可视化_第4张图片
HISTOGRAMS主要显示权重和偏置的直方图随着训练步数的变化过程。

除了上面演示的add_scalar和add_histogram方法外,TensorBoard的SummaryWriter还有一系列其他方法添加不同的数据到TensorBoard界面中,包括但不限于可以写入图片信息的add_image、显示准确率召回率曲线的add_pr_curve等等。

总结

PyTorch通过复用Python自带的序列化函数库pickle,同时构建了张量和模块的序列化方法,来实现深度学习模型的保存和载入。深度学习模型也可以很方便地输出和载入当前模型参数的状态字典,该状态字典和模型的分离也方便不同版本PyTorch训练模型之间的相互兼容。

为了能够方便地观察深度学习的中间结果和张量,以及损失函数的变化情况,PyTorch还集成了TensorBoard相关的插件,能够方便地在网页中对深度学习的模型输出的中间张量进行可视化,也方便了用户对深度学习模型的调试和效果评估。

你可能感兴趣的:(深度学习,python,深度学习,pytorch)