PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

训练模型时,在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果,同时也方便我们复现这些模型,用于之后的研究。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping,这个模块就是用来解决上述的模型保存与重载问题

一、保存与重载模块

若希望保存/加载模型model的参数,而不保存/加载模型的结构,可以通过如下代码

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第1张图片

其中state_dict是torch中的一个字典对象,将每一层与该层的对应参数张量建立映射关系

若希望同时保存/加载模型model的参数以及模型结构,而不保存/加载模型的结构,可以通过如下代码

为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。超参数之一就是训练周期(epoch),训练周期如果取值过小可能会导致欠拟合,取值过大可能会导致过拟合。为了避免训练周期设置不合适影响模型效果,EarlyStopping应运而生。EarlyStopping解决epoch需要手动设定的问题,也可以认为是一种避免网络发生过拟合的正则化方法 

EarlyStopping的原理可以大致分为三个部分:

将原数据分为训练集和验证集;

只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练;

将停止之后的权重作为网络的最终参数

初始化 early_stopping 对象:

EarlyStopping 对象的初始化包括三个参数,其含义如下:

patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。

verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。

delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型 

 定义一个函数,表示训练函数,希望通过 EarlyStopping 当测试集上的损失值有所下降时,将此时的信息打印出来,并且保存参数。 先创建将要用到的变量,以及初始化 earlystopping 对象

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第2张图片

之后训练模型并保存损失值,计算每次迭代在训练集和测试集上的损失值得均值,并保存

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第3张图片 

调用 EarlyStopping 中的_call_()模块,判断损失值是否下降,若下降则进行保存,并打印信息

最后调用torch.load()加载最后一次的保存点,即最优模型,并返回模型,以及每轮迭代在训练集、测试集上的损失值的均值

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第4张图片 

二、可视化模块

在模型训练过程中,有时不仅需要保持和加载已经训练好的模型,也需要将训练过程中的训练集损失函数、验证集损失函数、模型计算图(即模型框架图、模型数据流图)等保持下来,供后续分析作图使用

例如,通过损失函数变化情况,可以观察模型是否收敛,通过模型计算图,可以观察数据流动情况等

Tensorboard可以将数据、模型计算图等进行可视化,会自动获取最新的数据信息,将其存入日志文件中,并且会在日志文件中更新信息,运行数据或模型最新的状态。Tensorboard中常用的模块包括如下七类

add_graph():添加网络结构图,将计算图可视化。

add_image()/add_images():添加单个图像数据/批量添加图像数据。

add_figure():添加matplotlib图片。

add_scalar()/add_scalars():添加一个标量/批量添加标量,在机器学习中可用于绘制损失函数。

add_histogram():添加统计分布直方图。

add_pr_curve():添加P-R(精准率-召回率)曲线。  

add_txt():添加文字

Tensorboard的整体用法,参见下图 

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第5张图片 

 TensorBoard中可以使用add_graph()函数保存模型计算图,该函数用于在tensorboard中创建存放网络结构的Graphs,函数及其参数如下:

model(torch.nn.Module) 表示需要可视化的网络模型;

input_to_model(torch.Tensor or list of torch.Tensor)表示模型的输入变量,如果模型输入为多个变量,则用list或元组按顺序传入多个变量即可;

verbose(bool)为开关语句,控制是否在控制台中打印输出网络的图形结构 

例如,有一个数据类型为torch.nn.Module的变量model,输入的张量为input1和input2,期望返回模型计算图,则可以输入如下代码,即可在SummaryWriter的日志文件夹中保存数据流图

 PyTorch中SummaryWriter的输出文件夹一般为runs文件,保存的日志文件不可以直接双击打开,需要在cmd命令窗口中将目录导航到runs文件夹的上一级目录,并输入tensorboard –logdir runs即可打开日志文件,打开后复制链接到浏览器中,即可打开保存的模型计算图或数据变量等 

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第6张图片

TensorBoard中可以使用add_scalar()/add_scalars()函数保存一个或在一张图中保存多个常量,如训练损失函数值、测试损失函数值、或将训练损失函数值和测试损失函数值保存在一张图中。

add_scalar()函数及参数如下:

  

tag(string)为数据标识符;

scalar_value(float or string)为标量值,即希望保存的数值;

global_step(int)为全局步长值,可理解为x轴坐标 

 add_scalars()函数及参数如下:

main_tag(string)为主标识符,即tag的父级名称;

tag_scalar_dict(dict)为保存tag及tag对应的值的字典类型数据;

global_step(int)为全局步长值,可理解为x轴坐标。 

add_scalars()可以批量添加标量,例如,绘制y=xsinx、y=xcosx、y=tanx的图像,可以输入如下代码,保存的日志文件打开方式与上文所述相同

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第7张图片 

 PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)_第8张图片

创作不易 觉得有帮助请点赞关注收藏~~~ 

你可能感兴趣的:(PyTorch基础,pytorch,深度学习,人工智能)