【pytorch】现有网络模型的使用和修改、保存和加载

文章目录

      • 一、网络模型的使用和修改
      • 二、网络模型的保存和加载
        • (1)模型的保存
        • (2)模型的加载

一、网络模型的使用和修改

vgg16_true.add_module(‘add_linear’,nn.Linear(1000,10))

import torchvision
from torch import nn
from torchvision import models

# train_data = torchvision.datasets.ImageNet("./data_image_net",split='train',download=True,
#                                            transform=torchvision.transforms.ToTensor())

#pytorch
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=models.VGG16_Weights.DEFAULT)
# print("OK")
# print(vgg16_true)

train_data = torchvision.datasets.CIFAR10("./data",train=True,transform=torchvision.transforms.ToTensor(),
                                          download=True)

vgg16_true.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)

二、网络模型的保存和加载

(1)模型的保存

  • torch.save(vgg16_false,“vgg16_method1.pth”)

  • torch.save(vgg16_false.state_dict(),“vgg16_method2.pth”)

import torch
import torchvision
from torch import nn
from torchvision import models

# train_data = torchvision.datasets.ImageNet("./data_image_net",split='train',download=True,
#                                            transform=torchvision.transforms.ToTensor())

#pytorch
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=models.VGG16_Weights.DEFAULT)

#保存方式1
torch.save(vgg16_false,"vgg16_method1.pth")


#保存方式2
torch.save(vgg16_false.state_dict(),"vgg16_method2.pth")

# print("OK")
# print(vgg16_true)

train_data = torchvision.datasets.CIFAR10("./data",train=True,transform=torchvision.transforms.ToTensor(),
                                          download=True)

vgg16_true.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)


【pytorch】现有网络模型的使用和修改、保存和加载_第1张图片

保存之后目录中会出现对应的pth文件,出现之后所保存的模型就可以用于之后的加载操作了。

(2)模型的加载

import torch

#方式1->保存方式1加载模型
import torchvision

model = torch.load("vgg16_method1.pth")
print("vgg16_false1:",model)

#方式2->保存方式2加载模型
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_false.load_state_dict(torch.load("vgg16_method2.pth"))
print("vgg16_false2:",vgg16_false)

显示模型(局部):

【pytorch】现有网络模型的使用和修改、保存和加载_第2张图片

你可能感兴趣的:(pytorch,深度学习,人工智能)