pytorch预训练模型下载保存路径和路径更改

在模型的Finetune初始化的阶段:
预训练模型在线下载,下模型后的地址默认是:

~/.cache/torch/hub/checkpoints

预训练模型的网络可以通过下面的代码得到

net = torchvision.models.vgg16(pretrained=True)

如果没有预先下载好预训练模型,在运行这个代码后,自动下载预训练模型的。
如果国内的网络速度慢,建议先手动下载预训练模型,放入制定或默认下载目录下。
如果要更改路径,有两种办法:
第一种办法:
通过源代码提供的线索:当pretrained为True时,torch会调用torch.utils的load_state_dict_from_url函数,这个函数最终调torch.utils.model_zoo.load_url函数。其中的参数model_dir就是保存的目录,这里它默认会使用环境变量TORCH_HOME。
默认情况下环境变量TORCH_HOME的值为~/.cache,在windows下就是%USERPROFILE%.cache。

因此要修改PyTorch下载文件的保存路径,只要在下载模型前,修改环境变量TORCH_HOME的值即可。可以在操作系统里设置对应的环境变量。
还可以在代码里临时添加对应的条目,对应的代码如下:

import os
os.environ['TORCH_HOME']='E:/Data/torch-model'

注意:每次python重新启动都需要重新运行一遍,设置上。
在运行上面代码之后,重新加载模型时就会在这个目录下载和加载。
第二种办法:
需要到torch安装路径下找到hub.py,在hub.py,搜load_state_dict_from_url,成功找到如下代码:
在这里插入图片描述
model_dir参数即为下载模型的默认路径,直接将model_dir = None换成model_dir = 想要的模型下载绝对路径即可。
参考:https://blog.csdn.net/ProLover98/article/details/104792115

你可能感兴趣的:(pytorch,pytorch,深度学习,python,人工智能,机器学习)