PyTorch框架训练的几种模型区别

PyTorch系列文章目录


文章目录

  • PyTorch系列文章目录
  • 前言
  • 一、.pt模型使用介绍
  • 二、.pth模型使用介绍
  • 三、.pth.tar模型使用介绍
    • 保存
    • 加载
  • 总结


前言

在PyTorch中,.pt、.pth和.pth.tar都是用于保存训练好的模型的文件格式,它们之间的主要区别如下:

.pt文件是PyTorch 1.6及以上版本中引入的新的模型文件格式,它可以保存整个PyTorch模型,包括模型结构、模型参数以及优化器状态等信息。.pt文件是一个二进制文件,可以通过torch.save()函数来保存模型,以及通过torch.load()函数来加载模型。

.pth文件是PyTorch旧版本中使用的模型文件格式,它只保存了模型参数,没有保存模型结构和其他相关信息。.pth文件同样是一个二进制文件,可以通过torch.save()函数来保存模型参数,以及通过torch.load()函数来加载模型参数。

.pth.tar文件是一个压缩文件,它包含一个.pth文件以及其他相关信息,比如模型结构、优化器状态、超参数等。.pth.tar文件可以通过Python的标准库tarfile来解压,然后通过torch.load()函数来加载模型。

总的来说,.pt文件是最新的、最全面的模型保存格式,可以保存整个PyTorch模型,包括模型结构、参数、优化器状态等信息。.pth文件只保存了模型参数,而.pth.tar文件则是在.pth基础上加入了一些元数据信息,可以方便地保存和加载整个模型状态。在实际应用中,我们可以根据需要选择适合自己的模型保存格式。


一、.pt模型使用介绍

.pt模型文件是PyTorch框架中保存模型权重的文件格式,其结构包含以下几个部分:
Header:文件开头的一段信息,包含了PyTorch版本、模型结构等元数据信息。
State dictionary:模型的权重数据,以Python的字典形式保存。每个键对应了模型的一个参数名,值则是对应的权重矩阵或向量。
Optimizer state:如果模型使用了优化器,那么这里保存了优化器的状态信息,包括当前的学习率、动量等参数。
Other metadata:保存了一些附加的元数据信息,比如模型训练时使用的超参数、训练数据集的统计信息等。
要解读.pt模型文件的信息,可以使用PyTorch提供的torch.load()函数来加载模型文件,然后可以通过访问字典中的键值对来获取模型的权重和其他信息。例如,可以使用以下代码加载模型文件并查看模型结构和权重:

import torch
model = torch.load('model.pt')
print(model)

该代码会输出模型的结构和权重信息,可以通过访问字典中的键值对来获取具体的权重数值。例如,可以使用以下代码获取模型中名为’conv1.weight’的卷积层权重矩阵:

weights = model['conv1.weight']
print(weights)

这样就可以查看模型文件中保存的权重信息,并进一步用于模型的部署或微调等操作。

二、.pth模型使用介绍

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。

首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。
先举最简单的例子:

import torch

model = torch.load('my_model.pth')
torch.save(model, 'new_model.pth')

上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构。当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”。

如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict”。这个所谓的state_dict就是不带模型结构的模型参数了~
咱们的加载和保存就发生了一点微妙的变化:

import torch
model = MyModel() # init your model class, build the graph shape
state_dict = torch.load('model_state_dict.pth')
model.load_state_dict(state_dict)
torch.save(model.state_dict(), 'model_state_dict1.pth')

比较上面两段代码,咱们可以有一下结论:

pth文件既可能保存了模型的图结构,也有可能没保存;
加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

实验
脚本如下:

import torch
import torchvision.models as models

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'only_weights.pth')

model_state_dict = torch.load('only_weights.pth')
model1 = models.vgg16() # describe the graph shape
model1.load_state_dict(model_state_dict)
model1.eval()

torch.save(model1, 'whole_model.pth')

model2 = torch.load('whole_model.pth')
model2.eval()

# model3 = torch.load('only_weights.pth')
# model3.eval()    # Error

model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

三、.pth.tar模型使用介绍

由于为我的特定应用程序重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型。
此模型保存为 .pth.tar文件。
我希望能够首先加载这个模型。到目前为止,我已经能够弄清楚我必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

这似乎有效,因为 print(model)打印出大量数字和其他值,我认为这些值是权重和偏差的值。
在此之后,我需要能够用它对图像进行分类。我一直无法弄清楚这一点。我必须如何格式化图像?图像是否应该转换为数组?在此之后,我必须如何将输入数据传递给网络?

如果您有 .pth.tar文件,您可以加载它,从而覆盖已定义模型的参数值。

这意味着保存/加载模型的一般过程如下:
编写您的网络定义(即您的 nn.Module 对象)
以您想要的方式训练或以其他方式更改网络参数
使用 torch.save 保存参数
当您想使用该网络时,请使用 nn.Module 的相同定义对象首先实例化 pytorch 网络
然后使用 torch.load 覆盖网络参数的值

这是一个超短的 mwe:

保存

torch.save({
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')

加载

checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

总结

你可能感兴趣的:(PyTorch系列,深度学习,ONNX,pytorch,深度学习,机器学习)