PyTorch入门学习(十六):网络模型的保存与读取

目录

一、使用预训练模型

二、修改预训练模型的层内容

三、保存和加载模型


一、使用预训练模型

PyTorch提供了许多预训练的深度学习模型,如VGG、ResNet、Inception等。这些模型在大规模数据集上进行了训练,并在各种计算机视觉任务中表现出色。可以轻松地使用这些模型,甚至可以进行一些自定义修改以适应特定的任务。

在下面的代码示例中,展示了如何使用PyTorch中的VGG16模型,并在其基础上进行自定义修改以适用于CIFAR-10数据集:

import torchvision
from torch import nn

# 创建一个未初始化的VGG16模型
vgg16_false = torchvision.models.vgg16(weights=None)

# 创建一个包含预训练权重的VGG16模型
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)

# 对VGG16模型进行自定义修改,适用于CIFAR-10数据集
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

# 输出模型结构
print(vgg16_true)

通过上述代码,可以加载VGG16模型,并在其classifier部分添加一个全连接层,将输出维度从1000修改为10,以适应CIFAR-10数据集的分类任务。使能够利用预训练的特征提取部分,同时对输出层进行自定义。

二、修改预训练模型的层内容

有时候,需要修改预训练模型的内部层内容,以适应不同的任务。下面的示例演示了如何修改VGG16模型的输出层,以适应新的分类任务:

import torchvision
from torch import nn

# 创建一个未初始化的VGG16模型
vgg16_false = torchvision.models.vgg16(weights=None)

# 修改VGG16模型的输出层
vgg16_false.classifier[6] = nn.Linear(4096, 10)

# 输出模型结构
print(vgg16_false)

在这个示例中,直接修改了VGG16模型的第7个全连接层(索引为6),将其输出维度从4096修改为10。这允许在不改变模型的其他部分的情况下,为新任务创建一个适用的模型。

三、保存和加载模型

一旦自定义了模型,可能需要将其保存到磁盘以备将来使用,或者加载预训练的模型以进行进一步的微调。PyTorch提供了保存和加载模型的功能。

保存模型

torch.save(vgg16_false.state_dict(), 'custom_vgg16.pth')

使用上述代码,将自定义的VGG16模型的权重保存到名为"custom_vgg16.pth"的文件中。

加载模型

loaded_model = torchvision.models.vgg16(weights=None)  # 创建一个未初始化的模型
loaded_model.load_state_dict(torch.load('custom_vgg16.pth'))  # 加载保存的权重

通过上述代码,首先创建一个未初始化的VGG16模型,然后加载保存的权重,这样就可以使用已保存的自定义模型了。

完整代码如下:

import torchvision
from torch import nn
from torchvision.models import VGG16_Weights

# train_data = torchvision.datasets.ImageNet("D:\\Python_Project\\pytorch\\data_image_net",split="train",download=True,transform=torchvision.transforms.ToTensor())

# 错误原因:参数pretrained自0.13起已弃用,将在0.15后删除,要改用“weights”。
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)

# print(vgg16_true)

# 要想用于 CIFAR10 数据集, 可以在网络下面多加一行,转成10分类的输出,这样输出的结果,跟下面的不一样,位置不一样
# vgg16_true.add_module('add_Linear',nn.Linear(1000,10))
# print(vgg16_true)

vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# 层级不同
# 如何利用现有的网络,改变结构
print(vgg16_true)

# 上面是添加层,下面是如何修改VGG里面的层内容
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096,10)  # 中括号里的内容,是网络输出结果自带的索引,套进这种格式,就可以直接修改那一层的内容
print(vgg16_false)

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

你可能感兴趣的:(PyTorch,学习)