【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径(使用models.__dict__[model_name]()读取)

说明

使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即C:\用户名\.cache\torch\.checkpoints下,便于统一管理,我决定修改model的存放路径,在网上找了很久都没有很好的解决方法,只能自己尝试,现将解决方案给出,供大家参考~

操作环境

  • windows10 + Anaconda
  • torch:1.1.0
  • torchvision:0.3.0

加载方式

以加载vgg16为例,首先定义网络结构

class CNN(nn.Module):

    def __init__(self, usegpu=True):
        super(CNN, self).__init__()

        self.model = models.__dict__['vgg16'](pretrained=True)
        self.model = nn.Sequential(*list(self.model.children())[0]) #TODO resnet50 :-5
        self.model = nn.Sequential(*list(self.model.children())[:16])

    def forward(self, x):
        x = self.model(x)

        return x

在执行model = CNN()时即会触发下载动作,此时会默认下载到C:\用户名\.cache\torch\.checkpoints

修改方法

总的原则就是修改源码

  1. 经过这篇博客的介绍,发现pytorch的默认下载路径是由load_state_dict_from_url函数进行控制,那么就好办了,只需要找到这个函数进行修改即可
  2. 由于我是下载vgg16,所以我先找到vgg.py源码,位于python路径/torchvision/models/vgg.py(其中python路径即是你安装python所在的地址,假设你是用Anaconda创建了一个名为envtest的虚拟环境,那么所有安装的库都会在Anaconda/envs/envtest/Lib/site-packages这个文件夹下)
  3. 接下来就是套娃操作
    1. vgg.py直接搜索load_state_dict_from_url
      发现有如下语句
      vgg.py
      所以需要在本目录下找到utils.py,看看里面有没有load_state_dict_from_url函数
    2. 打开utils.py,只有如下代码
      utils.py
      显然,需要到torch安装路径下找到hub.py
      3.找到hub.py,搜索load_state_dict_from_url,成功找到如下代码
      【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径(使用models.__dict__[model_name]()读取)_第1张图片
      根据描述,可发现model_dir参数即为下载模型的默认路径,所以直接将model_dir = None换成model_dir = 想要的模型下载绝对路径即可,感兴趣的同学可以仔细专研,这里就不过多阐述

后记

这里说明一下我上面这么做的原因:

  1. 为什么需要找到源码写死?直接使用函数修改不行吗?
    因为考虑到后续可能还要下载其他预训练模型,与其每次进行修改,还不如一次性写死
  2. 为什么需要绝对路径,相对路径不行吗?
    相对路径也可以,但考虑到绝对路径更加直观,所以我这里使用的绝对路径

你可能感兴趣的:(Deep,Learning)