现有网路模型的使用及其修改————PyTorch

哔哩大学的PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
的P25讲讲述了如何加载一些现有的pytorch网络模型,或对网络模型中的一些结构进行修改.

首先安装scipy
现有网路模型的使用及其修改————PyTorch_第1张图片

修改代码为:
(其实我听不太懂,看不懂添加修改的啥,详情看开头的视频链接)

import torchvision

# train_data = torchvision.datasets.ImageNet("data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())
# 模型太大,数据不公开,所以没法下载
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)

train_data = torchvision.datasets.CIFAR10("data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
# 发现有添加

print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
# 发现有修改

结果展示;
现有网路模型的使用及其修改————PyTorch_第2张图片

你可能感兴趣的:(pytorch,pytorch,深度学习)