Pytorch 预训练模型下载和加载

PyTorch 加载和下载预训练模型可参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法

- 下载地址

常用预训练模型在这里面:https://github.com/pytorch/vision/tree/master/torchvision/models

但是上述网址只有常见的 backbone (vgg, resnet, densenet, alexnet),在 GitHub 上,还找到了一个项目,提供 NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN 等预训练模型的下载:https://github.com/Cadene/pretrained-models.pytorch

具体下载位置是:https://data.lip6.fr/cadene/pretrainedmodels/

- 加载预训练模型

一般使用的是使用 model.load_state_dict() 函数。

model_urls = {  'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',}
def resnet50(pretrained=False, **kwargs):
	model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
	if pretrained:
		model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
	return model

此时它会到指定的网站下载预训练模型到本地缓存中,本地缓存的位置(Linux系统)一般在:

.cache/torch/checkpoints

PyTorch 在加载模型时候首先检查本地缓存是否已经存在预训练模型,所以在本地缓存汇总预先放入已经下载的模型可快速加载模型。

如果需要更改预训练模型的位置,可以在文件开头加入:

os.environ['TORCH_HOME']= './pretrained_models/'

pretrained_models 文件夹下新建一个 checkpoints 文件夹并把预训练模型放入即可。

- 参考

  1. pytorch预训练模型下载URL及加载调用方法
  2. pytorch学习笔记之加载预训练模型

你可能感兴趣的:(#,Python模块有关问题)