Fixing torch.hub.load errors when loading a model over the network

1 torch.hub.load fails when loading a model over the network

The model is loaded over the network with torch.hub.load as follows:

self.model = torch.hub.load("facebookresearch/dinov2", 'dinov2_vits14', source='github').to(self.device)

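When running projects found online, this call often hangs or times out. With source='github', torch.hub.load has to contact GitHub: as the traceback below shows, `_parse_repo_info` requests https://github.com/facebookresearch/dinov2/tree/main/ to work out which branch to use. GitHub is often unreachable without a proxy, so the request fails with a network error such as: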

Traceback (most recent call last):
  File "/opt/pa_retrieve/preprocessor/remove_redundant_image.py", line 4, in <module>
    from model.dinov2_embeding_small import dinov2_embeding_small
  File "/opt/pa_retrieve/model/dinov2_embeding_small.py", line 42, in <module>
    dinov2_embeding_small = Dinov2EmbedingSmall()
  File "/opt/pa_retrieve/model/dinov2_embeding_small.py", line 20, in __init__
    self.model = torch.hub.load("facebookresearch/dinov2", 'dinov2_vits14', source='github').to(self.device)
  File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 555, in load
    repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
  File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 199, in _get_cache_or_reload
    repo_owner, repo_name, ref = _parse_repo_info(github)
  File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 142, in _parse_repo_info
    with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 214, in urlopen
    return opener.open(url, data, timeout)
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 517, in open
    response = self._open(req, data)
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 534, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 494, in _call_chain
    result = func(*args)
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 1389, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 1350, in do_open
    r = h.getresponse()
  File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 1377, in getresponse
    response.begin()
  File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 320, in begin
    version, status, reason = self._read_status()
  File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 289, in _read_status
    raise RemoteDisconnected("Remote end closed connection without"
http.client.RemoteDisconnected: Remote end closed connection without response
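If a proxy is available, one way to let torch.hub reach GitHub is through the standard proxy environment variables. torch.hub makes its requests with urllib (see `urlopen` in the traceback), and urllib's default opener reads `http_proxy`/`https_proxy` via `urllib.request.getproxies()`. This is a minimal sketch under assumptions: the proxy address is hypothetical, and the variables must be set before the first urllib request in the process, because urllib builds its default opener once and then reuses it. Exporting the same variables in the shell before starting Python has the same effect.

```python
import os
import urllib.request

# Hypothetical proxy address; replace it with your own proxy.
PROXY_URL = "http://127.0.0.1:7890"


def enable_proxy(proxy_url: str) -> dict:
    """Export proxy variables for this process and return what urllib will use.

    Call this before torch.hub makes its first request: urllib builds its
    default opener (including the ProxyHandler) once and reuses it afterwards.
    """
    for name in ("http_proxy", "https_proxy"):
        os.environ[name] = proxy_url
    # getproxies() is what urllib's default ProxyHandler reads at creation time.
    return urllib.request.getproxies()


if __name__ == "__main__":
    print(enable_proxy(PROXY_URL))
    # Afterwards, the original GitHub-based call can be made, e.g.:
    # import torch
    # model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", source="github")
```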

2 Loading a local model with torch.hub.load

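Download the model weights and the repository code once through a proxy (for instance, by running the GitHub-based call above once with the proxy enabled). The weights are stored in: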

/root/.cache/torch/hub/checkpoints/

The repository code is stored in:

/root/.cache/torch/hub/facebookresearch_dinov2_main/
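Both directories sit under torch.hub's cache directory, which `torch.hub.get_dir()` returns (by default `$TORCH_HOME/hub`, i.e. `~/.cache/torch/hub` when neither `TORCH_HOME` nor `XDG_CACHE_HOME` is set, hence `/root/.cache/torch/hub` for root). The repository folder is named `owner_repo_ref`, which is why the main branch of facebookresearch/dinov2 becomes `facebookresearch_dinov2_main`. The sketch below rebuilds both paths; `hub_cache_paths` and `show_local_cache` are hypothetical helpers, and the naming rule mirrors what current torch.hub versions do rather than a documented guarantee.

```python
import os


def hub_cache_paths(hub_dir: str, owner: str = "facebookresearch",
                    repo: str = "dinov2", ref: str = "main") -> tuple:
    """Return (repo_dir, checkpoints_dir) for a repo cached by torch.hub."""
    # torch.hub replaces "/" in the ref with "_" and joins owner, repo and ref.
    repo_dir = os.path.join(hub_dir, "_".join([owner, repo, ref.replace("/", "_")]))
    # load_state_dict_from_url() stores weights in <hub_dir>/checkpoints by default.
    checkpoints_dir = os.path.join(hub_dir, "checkpoints")
    return repo_dir, checkpoints_dir


def show_local_cache() -> None:
    """Print the cache paths on this machine (requires torch to be installed)."""
    import torch  # imported lazily so hub_cache_paths() works without torch

    repo_dir, ckpt_dir = hub_cache_paths(torch.hub.get_dir())
    print("repo:", repo_dir, "exists:", os.path.isdir(repo_dir))
    if os.path.isdir(ckpt_dir):
        print("checkpoints:", sorted(os.listdir(ckpt_dir)))
    else:
        print("checkpoints: missing")
```

Calling `show_local_cache()` on the machine in question is a quick way to confirm that both directories exist before switching to local loading.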

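Then change the loading code to load from this local directory, passing the directory path together with source='local':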

self.model = torch.hub.load('/root/.cache/torch/hub/facebookresearch_dinov2_main', 'dinov2_vits14', trust_repo=True, source='local').to(self.device)

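Run the program again and the model loads successfully. With source='local', torch.hub.load imports hubconf.py directly from the given directory and makes no request to GitHub. DINOv2's hubconf still requests the pretrained weights through torch.hub.load_state_dict_from_url, but because the weight file is already in /root/.cache/torch/hub/checkpoints/, it is read from disk instead of being downloaded.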
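To keep the same code working on machines with and without the local copy, one option is to prefer the local directory when it contains `hubconf.py` (the entry-point file torch.hub.load looks for with source='local') and fall back to GitHub otherwise. A minimal sketch: `pick_hub_source` and `load_dinov2` are hypothetical helpers, and the local path is the one used above.

```python
import os

LOCAL_REPO = "/root/.cache/torch/hub/facebookresearch_dinov2_main"
GITHUB_REPO = "facebookresearch/dinov2"


def pick_hub_source(local_repo: str = LOCAL_REPO,
                    github_repo: str = GITHUB_REPO) -> tuple:
    """Return (repo_or_dir, source) for torch.hub.load, preferring a local copy."""
    if os.path.isfile(os.path.join(local_repo, "hubconf.py")):
        return local_repo, "local"
    return github_repo, "github"


def load_dinov2(device: str = "cpu", model_name: str = "dinov2_vits14"):
    """Load a DINOv2 backbone, avoiding GitHub when the local copy exists."""
    import torch  # imported lazily so pick_hub_source() works without torch

    repo_or_dir, source = pick_hub_source()
    # trust_repo only matters for source="github"; it is ignored for local loads.
    model = torch.hub.load(repo_or_dir, model_name, source=source, trust_repo=True)
    return model.to(device)
```

Inside the class from section 1, `self.model = load_dinov2(self.device)` would then replace the original call.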
