PyTorch学习笔记(7)网络模型的使用与修改

本文使用 pytorch 自带的网络模型,并略作修改,使它能够适应其它的数据集。

文章目录

    • 下载
    • 修改模型
    • 模型的保存与读取


下载

这里以 vgg16 为例子。

pretrained=True 表示下载已经训练好的模型,即包含相关参数,progress=True 表示显示进度条。

import torchvision
from torch import nn

# vgg16_false = torchvision.models.vgg16(pretrained=False)    #未训练好的模型
vgg16_true = torchvision.models.vgg16(pretrained=True,progress=True)

print(vgg16_true) # 输出网络结构

通过调试可以查看模型具体属性:

PyTorch学习笔记(7)网络模型的使用与修改_第1张图片

通过打印,我们可以看到详细的网络结构。然而 vgg 模型是对 1000 个类别进行分类的,我们应该怎样使用呢?

PyTorch学习笔记(7)网络模型的使用与修改_第2张图片

修改模型

假如我们想使用 CIFAR10 数据集,去套用 vgg 的网络,那么就需要在最后加一个线性层 nn.Linear(1000,10),或者将原有的线性层输出改为 10。
代码如下:

import torchvision
from torch import nn

# vgg16_false = torchvision.models.vgg16(pretrained=False)    #未训练好的模型
vgg16_true = torchvision.models.vgg16(pretrained=True,progress=True)

# 如何应用这个网络模型?
# out_features=1000改为10
train_data = torchvision.datasets.CIFAR10("./dataset_CIFAR10",train=True,transform=torchvision.transforms.ToTensor(), download=False)

# 增加一个线性层
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print(vgg16_true)

# 修改线性层
# vgg16_true.classifier[6] = nn.Linear(4096,10)
# print(vgg16_true)

模型的保存与读取

有两种方式,一个是同时保存模型结构参数,一个是只保存参数。
不同的保存方式,对应的加载方式也略有不同。

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
model = torch.load("vgg16_method1.pth")
print(model)

# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
model = torch.load("vgg16_method2.pth")
print(model)

如果只保存了参数,那么在使用时需要加载模型结构:

import torch
import torchvision
from torch import nn

# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))

你可能感兴趣的:(机器学习,pytorch,学习,深度学习)