【土堆pytorch实战】P25 vgg16模型的修改

P25现有模型的使用及修改

  • 查看vgg16模型的参数
import torchvision

vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_True=torchvision.models.vgg16(pretrained=True)
print(vgg16_True)

【土堆pytorch实战】P25 vgg16模型的修改_第1张图片

  • 在vgg16基础上增加线性层
import torchvision
from torch import nn

vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)
print(vgg16_true)
#增加一个线性层,使得in_features=1000,out_features=10
train_data=torchvision.datasets.CIFAR10("./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)

vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10)) #在vgg16的classifier下面增加线性层
print(vgg16_true)
vgg16_false.classifier[6]=nn.Linear(4096,10)  #修改vgg16——false的第六层
print(vgg16_false)

【土堆pytorch实战】P25 vgg16模型的修改_第2张图片
【土堆pytorch实战】P25 vgg16模型的修改_第3张图片

你可能感兴趣的:(pytorch实战,pytorch,深度学习)