【迁移学习】pytorch中如何加载已经训练好的模型

torchvisionmodels中包含很多用于图像分类、视频分类、目标检测等模型,例如vgg、resnet、inception v3等,我们既可以加载已经训练好的模型(预训练模型均是在ImageNet上进行训练的),也可以加载未经训练的模型,方法有两种,拿vgg来说:

torchvision.models.vgg19(pretrained=False, progress=True, **kwargs)

pretrained: 为True时,返回在ImageNet上的预训练模型。
progress: 为True时,返回下载进度条。

这种方法有可能会受到网络的影响,从而下载缓慢。另一种方法是先把预训练好的模型下载下来,然后加载一个pretrained=False的vgg模型,然后再把下载好的预训练模型的所有参数恢复到 “空壳”vgg中。

各种vgg模型的下载地址 通过github上的网址下载到指定文件夹中,模型后缀为pth

pretrained_net = torchvision.models.vgg19(pretrained=False)
load_model = torch.load('./vgg/vgg19-dcbb9e9d.pth')
pretrained_net.load_state_dict(load_model)

通过torch.load装载下载好的vgg预训练模型,然后通过load_state_dict进行参数拷贝,将vgg19-dcbb9e9d.pth的参数拷贝至pretained_net中。

你可能感兴趣的:(深度学习,Pytorch,迁移学习,pytorch,pytorch加载预训练模型)