PyTorch ships a number of ready-made network models (in torchvision.models), and reusing one of them is a very quick way to tackle a classification problem. However, the number of classes in our dataset usually differs from the number of neurons in the existing model's output layer: VGG-16, for example, has 1000 output neurons, while CIFAR-10 has only 10 classes. We therefore often need to modify the output layer of the existing model, which can be done in the following three ways:
import torchvision
from torch import nn

# CIFAR-10 has 10 classes, so VGG-16's 1000-way output layer must be adapted
train_data = torchvision.datasets.CIFAR10("../dataset", train=True, download=True,
                                          transform=torchvision.transforms.ToTensor())
vgg_16_false = torchvision.models.vgg16(pretrained=False, progress=True)
# Apply only ONE of the three methods below per run; they are alternatives
# Method 1: append a new layer at the top level of the model
vgg_16_false.add_module("add_linear", nn.Linear(1000, 10))
# Method 2: append a new layer inside the classifier sub-network
vgg_16_false.classifier.add_module("add_linear", nn.Linear(1000, 10))
# Method 3: replace layer (6) of the classifier sub-network
vgg_16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg_16_false)
Before modification:
VGG(
...
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
Method 1: add a layer at the top level of the original network; the layer is named add_linear and is a linear layer nn.Linear(1000, 10). Note, however, that this only registers the new module under the VGG object; VGG's hard-coded forward() never calls it, as the sketch after the output below shows.
VGG(
...
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
(add_linear): Linear(in_features=1000, out_features=10, bias=True)
)
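This caveat is my own observation, not from the original tutorial: add_module at the top level only registers the layer as a child of the model, but torchvision's VGG.forward() is hard-coded to run features → avgpool → classifier, so add_linear is never executed. A minimal sketch to check this on a fresh model (the 224×224 dummy input is my own choice):
import torch
import torchvision
from torch import nn

model = torchvision.models.vgg16(pretrained=False)
model.add_module("add_linear", nn.Linear(1000, 10))
x = torch.randn(1, 3, 224, 224)  # dummy ImageNet-sized input
print(model(x).shape)            # torch.Size([1, 1000]) -- still 1000-way:
                                 # VGG.forward() never calls the top-level add_linear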
Method 2: add a layer inside the classifier sub-network; the layer is named add_linear and is a linear layer nn.Linear(1000, 10). Since classifier is an nn.Sequential, the appended layer actually runs at the end of its forward pass (see the sketch after the output below).
VGG(
...
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
(add_linear): Linear(in_features=1000, out_features=10, bias=True)
)
)
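By contrast, nn.Sequential's forward() iterates over all registered children in order, so a layer appended inside classifier does run. A minimal sketch of the same check, under the same assumptions as above:
import torch
import torchvision
from torch import nn

model = torchvision.models.vgg16(pretrained=False)
model.classifier.add_module("add_linear", nn.Linear(1000, 10))
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 10]) -- Sequential runs its children
                       # in registration order, so add_linear executes last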
Method 3: replace layer (6) of the classifier sub-network with the linear layer nn.Linear(4096, 10) (in_features must stay 4096 to match the 4096-dimensional output produced by the preceding layers).
VGG(
...
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=10, bias=True)
)
)
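For completeness, the typical use of method 3 is with pretrained weights, keeping the learned features and re-initializing only the head; a minimal sketch, assuming the same pretrained= API as above (the vgg_16_true name is my own):
import torch
import torchvision
from torch import nn

vgg_16_true = torchvision.models.vgg16(pretrained=True)  # downloads the ImageNet weights
vgg_16_true.classifier[6] = nn.Linear(4096, 10)          # re-initialize only the 10-class head
x = torch.randn(1, 3, 224, 224)
print(vgg_16_true(x).shape)                              # torch.Size([1, 10])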
Thanks to the Bilibili uploader 我是土堆~