先有网络模型的使用及修改

文章

    • 先有网络模型的使用
    • 先有网络模型的修改(如何利用现有的网络去改动它的一个结构)
      • 1.添加网络层
      • 2.直接修改网络

先有网络模型的使用

使用示例代码:

import torchvision
from torch import nn

# 加载网络

# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)

# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)

vgg16的使用有两个常用参数,分别是pretrainedprocess

  • pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
  • process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
    通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。

先有网络模型的修改(如何利用现有的网络去改动它的一个结构)

1.添加网络层

示例代码如下:

import torchvision
from torch import nn

# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)

vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)

# 如何利用现有的网络去改动他的一个结构

# 1.添加网络层

# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。

# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))

# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))

print("vgg16_true:\n",vgg16_true)

运行结果:
先有网络模型的使用及修改_第1张图片
先有网络模型的使用及修改_第2张图片

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。

2.直接修改网络

示例代码如下:

import torchvision
from torch import nn

# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)

vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)

# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络

# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())

# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。

# 按顺序对网络进行索引,修改最后的线性层 
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)

运行结果:
先有网络模型的使用及修改_第3张图片
先有网络模型的使用及修改_第4张图片

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。

你可能感兴趣的:(pytorch,网络,深度学习,人工智能)