pytorch学习笔记(七)——pytorch中现有网络模型的使用、修改、模型的保存、加载

目录

  • 一、pytorch中现有网络模型的使用、修改
  • 二、模型的保存和加载
    • 1. 模型的保存
    • 2.模型的加载

一、pytorch中现有网络模型的使用、修改

  1. 位于torchvision.models

  2. 使用vgg模型为例,采用的数据集是ImageNet,而ImageNet数据集使用前提需要有scipy包
    pip install scipy

    注意:ImageNet光训练集就有147.9G,而且不再能公开访问了

  3. pytorch中使用现有网络模型以及修改现有的网络模型代码示例

import torchvision

# train_data = torchvision.datasets.ImageNet("../data_image_net", split="train", download=True,
#                                            transform=torchvision.transforms.ToTensor())
from torch import nn

"""
理解:
1. pretrained=False时,相当于使用pytorch中现有的网络模型,其中各层的参数采用默认的
2. pretrained=True时,相当于使用pytorch中现有的网络模型,但其中各层的参数采用 我们在数据集上训练好的参数
"""

# 1.使用现有的网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

# 2.在现有的网络模型中添加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

# 3.修改现有网络中的某层的参数
vgg16_false.classifier[7] = nn.Linear(4096, 10)

二、模型的保存和加载

1. 模型的保存

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1,保存了网络模型的结构以及其中的参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2,把网络模型的参数保存成字典,不再保存网络模型的结构(官方推荐)占的空间小
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 陷阱,用方式1保存自己写的神经网络
class MyNeural(nn.Module):
    def __init__(self):
        super(MyNeural, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x


my_neural = MyNeural()
torch.save(my_neural, "my_neural_method1.pth")

2.模型的加载

import torch
import torchvision
from c17_model_save import *

vgg16 = torchvision.models.vgg16(pretrained=False)

# 加载方式1,对应保存方式1
model = torch.load("vgg16_method1.pth")
print(model)

# 加载方式2,对应保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

# 陷阱1,
# 要让该.py文件加载自己定义的神经网络,需要引入自己定义的神经网络的模板类 from c17_model_save import *
model = torch.load("my_neural_method1.pth")
print(model)

你可能感兴趣的:(pytorch,pytorch,深度学习,python)