(1)Pytorch模型添加、删除、修改网络层

Pytorch模型添加、删除、修改网络层

import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=True) # 加载预训练网络模型

## 添加层
vgg16.add_module("add_linear_end", nn.Linear(1000, 10))  # 在整个module外面增加一个Linear (1)
vgg16.classifier.add_module("add_linear", nn.Linear(1000, 10)) # (2)

## 修改层
# 方式一:知道输入特征维度
vgg16.classifier[6] = nn.Linear(4096, 10) # (3)

# ⭐方式二:不用自己查,直接代码获得输入特征维度
num_fc = vgg16.classifier[6].in_features #读取输入特征的维度
vgg16.classifier[6] = nn.Linear(num_fc,2) #修改最后一层的输出维度,即分类数 (4)

## 删除层
del vgg16.classifier[6] # (5)
del vgg16.classifier[6] # (6)

输出截图如下

(1)(1)Pytorch模型添加、删除、修改网络层_第1张图片
(2)
(1)Pytorch模型添加、删除、修改网络层_第2张图片

(3)修改了分类器输出特征数(1)Pytorch模型添加、删除、修改网络层_第3张图片

(4)实现的功能和 3 一样,但是这种做法更具有普适性
(1)Pytorch模型添加、删除、修改网络层_第4张图片

(5)
(1)Pytorch模型添加、删除、修改网络层_第5张图片
(6)
(1)Pytorch模型添加、删除、修改网络层_第6张图片

参考资料: 深度学习pytorch:VGG网络模型的使用、修改及保存、添加线性层、修改网络输出_学好迁移Learning的博客

你可能感兴趣的:(pytorch基础知识积累,pytorch,深度学习,python)