现有网络模型的使用和修改

(联邦学习笔记,资料来源于b站小土堆)

在pytorch中,torchvison.module里有很多现有的模型可以直接使用,包括

  • Alexnet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3
  • GoogLeNet
  • ShuffleNet v2
  • MobileNet v2
  • MobileNet v3
  • ResNext
  • Wide ResNet
  • MNASNet
  • Quantized Models

 官方文档链接:https://pytorch.org/vision/stable/models.html#id2

本次使用VGG模型

VGG模型是2014年ILSVRC竞赛的第二名,第一名是GoogLeNet。但是VGG模型在多个迁移学习任务中的表现要优于googLeNet。而且,从图像中提取CNN特征,VGG模型是首选算法。它的缺点在于,参数量有140M之多,需要更大的存储空间。但是这个模型很有研究价值。

先来看看VGG的特点:

  • 小卷积核。作者将卷积核全部替换为3x3(极少用了1x1);

  • 小池化核。相比AlexNet的3x3的池化核,VGG全部为2x2的池化核;

  • 层数更深特征图更宽。基于前两点外,由于卷积核专注于扩大通道数、池化专注于缩小宽和高,使得模型架构上更深更宽的同时,计算量的增加放缓;

  • 全连接转卷积。网络测试阶段将训练阶段的三个全连接替换为三个卷积,测试重用训练时的参数,使得测试得到的全卷积网络因为没有全连接的限制,因而可以接收任意宽或高为的输入。

百度百科相关介绍链接:

 https://baike.baidu.com/item/VGG%20%E6%A8%A1%E5%9E%8B/22689655?fr=aladdin

VGG模型有两个参数:prtrained和progress

prtraine:如果为True,返回在ImageNet上预先训练过的模型

progress:如果为True,则显示下载到stderr的进度条

现有网络模型的使用和修改_第1张图片

引用现有的vgg模型:


#本次使用VGG16模型
#prtraine:如果为True,返回在ImageNet上预先训练过的模型
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
给现有的模型vgg16添加模型
vgg16_true.add_module("add_linear0",nn.Linear(1000,10))
print(vgg16_true)

  在vgg16的classifier层里添加模型 

vgg16_true.classifier.add_module("add_linear1",nn.Linear(1000,10))
在vgg16的classifier层的第六行改成,Linear(in_features=4096,out_features=10)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)

现有网络模型的使用和修改_第2张图片

完整代码展示:

import torchvision


#本次使用VGG16模型
#prtraine:如果为True,返回在ImageNet上预先训练过的模型
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)

train_data = torchvision.datasets.CIFAR10("../fl/data",train=True,transform=torchvision.transforms.ToTensor(),download=True)

#给现有的模型vgg16添加内容
vgg16_true.add_module("add_linear0",nn.Linear(1000,10))
print(vgg16_true)

#在vgg16的classifier层里添加模型

vgg16_true.classifier.add_module("add_linear1",nn.Linear(1000,10))

#在vgg16的classifier层的第六行改成,in_features=4096,out_features=10

vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)

你可能感兴趣的:(pytorch,联邦学习,网络模型,pytorch,网络模型)