Using and Modifying Existing Models in PyTorch

#model_pretrained


VGG16
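import torch.nn as nn
import torchvision

# VGG16 with randomly initialized weights
# (torchvision < 0.13: torchvision.models.vgg16(pretrained=False))
vgg16_false = torchvision.models.vgg16(weights=None)
# VGG16 with ImageNet-pretrained weights, downloaded on first use
# (torchvision < 0.13: torchvision.models.vgg16(pretrained=True))
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
print(vgg16_true)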
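The two models differ only in their weights, not in their architecture. In torchvision, requesting pretrained weights (`pretrained=True`, or `weights=...` in torchvision 0.13+) builds the same network and then loads a saved `state_dict` into it. The sketch below shows that idea on a hypothetical tiny network (`make_toy_model` is illustrative, not part of torchvision), so no weight download is needed:

```python
import torch
import torch.nn as nn

def make_toy_model():
    # Hypothetical miniature network standing in for VGG16
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))

torch.manual_seed(0)
source = make_toy_model()       # plays the role of a saved, trained checkpoint
random_init = make_toy_model()  # like vgg16 without pretrained weights
pretrained = make_toy_model()   # same architecture ...
pretrained.load_state_dict(source.state_dict())  # ... then load the saved weights

# Identical parameter shapes: the architecture is the same either way
same_arch = [tuple(p.shape) for p in random_init.parameters()] == [
    tuple(p.shape) for p in pretrained.parameters()
]
# The "pretrained" copy now holds exactly the checkpoint's values
weights_match = all(
    torch.equal(a, b) for a, b in zip(pretrained.parameters(), source.parameters())
)
print(same_arch, weights_match)
```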


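# CIFAR-10 has 10 classes, but VGG16's ImageNet classifier outputs 1000 classes
train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download=True)
# Append a linear layer that maps the 1000 ImageNet outputs to 10 classes
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))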
print(vgg16_true)
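`add_module` appends the new layer to the end of the `classifier` `nn.Sequential`, so it runs after the original 1000-way output layer. The sketch below shows this on a hypothetical scaled-down copy of VGG16's classifier. It keeps the same layer order (Linear, ReLU, Dropout, Linear, ReLU, Dropout, Linear), but the widths are shrunk from 25088/4096 to 32/16 so it runs instantly without torchvision:

```python
import torch
import torch.nn as nn

# Hypothetical miniature of VGG16's classifier; only the final 1000 is kept real
classifier = nn.Sequential(
    nn.Linear(32, 16), nn.ReLU(inplace=True), nn.Dropout(),
    nn.Linear(16, 16), nn.ReLU(inplace=True), nn.Dropout(),
    nn.Linear(16, 1000),
)

# Same call as above: append a layer mapping the 1000 outputs to 10
classifier.add_module("add_linear", nn.Linear(1000, 10))

classifier.eval()  # disable Dropout for a deterministic forward pass
with torch.no_grad():
    out = classifier(torch.randn(2, 32))  # batch of 2 feature vectors

print(len(classifier), out.shape)
print(classifier)  # the new layer is listed last as (add_linear)
```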

print(vgg16_false)
# Alternatively, replace the last classifier layer, Linear(4096, 1000), with Linear(4096, 10)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
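Assigning to `classifier[6]` swaps out the final layer in place, so the number of layers stays the same. The same hypothetical miniature classifier as before (shrunk widths, VGG16's layer order, index 6 is the last `Linear`) illustrates this:

```python
import torch
import torch.nn as nn

# Hypothetical miniature of VGG16's classifier; index 6 is the final Linear
classifier = nn.Sequential(
    nn.Linear(32, 16), nn.ReLU(inplace=True), nn.Dropout(),
    nn.Linear(16, 16), nn.ReLU(inplace=True), nn.Dropout(),
    nn.Linear(16, 1000),
)
old_head = classifier[6]

# Same pattern as vgg16_false.classifier[6] = nn.Linear(4096, 10)
classifier[6] = nn.Linear(16, 10)

classifier.eval()  # disable Dropout for a deterministic forward pass
with torch.no_grad():
    out = classifier(torch.randn(2, 32))

print(len(classifier), out.shape, old_head.out_features)
```

The two approaches trade off as follows. Appending keeps the ImageNet layer, Linear(4096, 1000) with 4,097,000 parameters, and adds 10,010 more. Replacing removes that layer entirely and leaves only 40,970 parameters in the new head.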
