现有网络模型的使用及修改(VGG16为例)

VGG16

修改默认路径

import os
os.environ['TORCH_HOME'] = r'D:\Pytorch\pythonProject\vgg16'  # 下载位置

太大了(140多G)不提供直接下载

train_set = torchvision.datasets.ImageNet(root='./data_image_net', split='train', download=True
                                          , transform=torchvision.transforms.ToTensor())

是否预训练

不预训练:采用随机参数
预训练:采用训练好的参数

第一次

现有网络模型的使用及修改(VGG16为例)_第1张图片
现有网络模型的使用及修改(VGG16为例)_第2张图片

第二次

现有网络模型的使用及修改(VGG16为例)_第3张图片
在这里插入图片描述

vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')  # or weights='IMAGENET1K_V1'

完整代码

import torchvision
import os

os.environ['TORCH_HOME'] = r'D:\Pytorch\pythonProject\vgg16'  # 下载位置

# train_set = torchvision.datasets.ImageNet(root='./data_image_net', split='train', download=True
#                                           , transform=torchvision.transforms.ToTensor())

vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')  # or weights='IMAGENET1K_V1'
print(vgg16_true)

现有网络模型的使用及修改(VGG16为例)_第4张图片

加一层线性层-nn.Linear

vgg16_true.add_module('add_linear', nn.Linear(1000, 10))

现有网络模型的使用及修改(VGG16为例)_第5张图片

如果想加到classifier里面

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

现有网络模型的使用及修改(VGG16为例)_第6张图片

修改神经网络某层

vgg16_false.classifier[6] = nn.Linear(4096, 10)

改之前

print(vgg16_false)

现有网络模型的使用及修改(VGG16为例)_第7张图片

改之后

现有网络模型的使用及修改(VGG16为例)_第8张图片

你可能感兴趣的:(python,神经网络,pytorch,深度学习)