PyTorch 现有网络模型的使用及修改

PyTorch 现有网络模型的使用及修改

先用最简单的 VGG分类模型 作为案例

最常用的是VGG16 和 VGG19两个版本

PyTorch 现有网络模型的使用及修改_第1张图片

  • pretrained
    如果为true 下载的网络模型中的参数在ImageNet数据集中已经训练好。
    ​ 如果为False 下载的网络模型中的参数没有训练过。

  • progress
    如果为True 显示下载进度条
    如果为False 不显示下载进度条

由于ImageNet数据集太大,就不下载了。



vgg16 架构

我们先看看看到 vgg_16 的网络架构

import torchvision.datasets
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)    # pretrained=False, 只加载网络模型, 不需要下载, 参数都是默认的
vgg16_true = torchvision.models.vgg16(pretrained=True)      # pretrained=True,  从网络上下载已经训练好的参数

print(vgg16_true)
# print(vgg16_false)

PyTorch 现有网络模型的使用及修改_第2张图片
PyTorch 现有网络模型的使用及修改_第3张图片



添加操作

由上面的结果可知,vgg模型是一个分类模型,分出的类有1000个

当我们只需要分 10 类时,该怎么办呢?

我们可以再加一个线性层, 令in_features=1000, out_features=10.

vgg16_true.add_module('add_linear', nn.Linear(1000, 10))

PyTorch 现有网络模型的使用及修改_第4张图片


如果想将新的线性层加入 (classifier) 中,只需要使用命令:

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10)) 即可

PyTorch 现有网络模型的使用及修改_第5张图片


修改操作

上面是在现有的网络模型中添加一些东西,但是如果我们想修改呢?

比如说修改vgg16_false 中 classifier 中的第七个线性层的参数

vgg16_false.classifier[6] = nn.Linear(4096, 10)

PyTorch 现有网络模型的使用及修改_第6张图片


代码

import torchvision.datasets
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)    # pretrained=False, 只加载网络模型, 不需要下载, 参数都是默认的
vgg16_true = torchvision.models.vgg16(pretrained=True)      # pretrained=True,  从网络上下载已经训练好的参数


"""
 CIFAR10 将数据分成了 10 类, 但是 VGG16 将数据分成了 1000 类, 如何应用这个vgg_16训练好的模型处理 CIFAR10 的数据呢?
 我们可以通过改变现有网络的结构来进行相应的操作
"""
train_data = torchvision.datasets.CIFAR10('dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)


# 我们再加一个线性层, in_features=1000, out_features=10.
print(vgg16_true)
vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

# 将线性层加到 classifier 中
vgg16_true.classifier.add_module('add_linear_to_classifier', nn.Linear(1000, 10))
print(vgg16_true)

# 修改vgg16_false
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

你可能感兴趣的:(pytorch,网络,深度学习)