pytorch 清楚缓存_Pytorch快速下载预训练模型并修改保存路径

【Pytorch】快速下载预训练模型并修改保存路径

首次用Pytorch加载预训练模型,需要在线下载,但是下载速度比较慢。下载后会保存在本地缓存里。如果能直接加载本地下载好的模型就会快了,主要是个修改路径的问题。

所以要提升速度一般有两种方法:

1.修改torch源码,一次性改变下载url

2.将离线模型权重存到缓存文件夹里

参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法 - you-wh - 博客园

参考:【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径_ProLover98的博客-CSDN博客

参考:pytorch 加载(.pth)格式的模型_人工智能_u014264373的博客-CSDN博客(没有修改存储路径)

但是用云服务器时候这两种方法都有点问题。如果预训练模型的下载路径和存储路径能随用随改就最好了。

所以以vgg16为例,本文采用的方法是:

import torch

from torchvision import models

pthfile = 'file:///mnt/model/vgg16-397923af.pth'  #在下载好的pth文件路径前加file:///得到url

pthsavefile = '/mnt/vgg16-397923af.pth'  #这是模型保存的路径

model = models.vgg.vgg16(pretrained=False, progress=True) #定义一个不需要预训练的模型。如果pretrained=True就会自动下载了

state_dict = torch.utils.model_zoo.load_url(pthfile, model_dir=pthsavefile,

map_location=None, progress=True, check_hash=False)

# 从pthfile下载到pthsavefile。默认model_dir为none

model.load_state_dict(state_dict) # 读取下载好的模型

# 设置好参数就可以train了

model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)

模型可以任意换成别的,比如

models.vgg.vgg16

models.resnet.resnet18

models.resnet.resnext50_32x4d

你可能感兴趣的:(pytorch,清楚缓存)