现有网络模型的使用和修改

#False,下载的是网络模型,默认参数
#vgg16_false = torchvision.models.vgg16(pretrained = False)

#True,下载的是网络模型,并且在数据集上面训练好的参数。
#vgg16_true = torchvision.models.vgg16(pretrained = True)

#学习在现有的网络进行修改。vgg16是将数据分为1000类。而数据集CIFAR10只有十类。
#1.给vgg16模in_feature=1000,out_feature = 10;
#2.直接修改,将最后一层改为out_feature = 10;

import torchvision
from torch import nn


vgg16_false = torchvision.models.vgg16(pretrained = False)
vgg16_true = torchvision.models.vgg16(pretrained = True)

print(vgg16_true)

train_data = torchvision.datasets.CIFAR10("a",train=True,transform=torchvision.transforms.ToTensor(),
                                         download = True)

#在vgg16的classifier下加一层模型,名叫add_linear,module名,in_feature=1000,out_feature=10
vgg16_true.classifier.add_module('add_lnear',nn.Linear(1000,10))
print(vgg16_true)

#修改最后一行结构为out_feature=10
vgg16_false.classfier[6] = nn.Linear(4096,10)
print(vgg16_false)

你可能感兴趣的:(深度学习,pytorch,神经网络)