pytorch低版本加载高版本pytorch训练得到的模型,出现‘module’ object has no attribute ‘_rebuild_tensor_v2’错误

情景
使用pytorch0.3来加载Mobilenetv1的模型(用更高版本的pytorch训练得到的),出现“AttributeError: ‘module’ object has no attribute ‘_rebuild_tensor_v2’”错误。

分析
追根溯源,查看pytorch的源码,torch下__init__.py定义了__all__ = [**, ‘load’, **],然后是from .serialization import load,然后去看serialization.py的源码,load函数下面调用了_load函数,但是 我并没有找到_rebuild_tensor_v2函数。在最新版的pytorch源码下,_utils.py下存在_rebuild_tensor_v2函数,但是在低版本下没有_rebuild_tensor_v2函数。个人猜想,在调用torch.load的时候,在某处调用了_rebuild_tensor_v2,但是在低版本pytorch的_utils.py文件中没有定义该函数,因此需要自己重新定义。

解决
参考:https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560 给出的解决方案,在调用torch.load的py文件中,添加以下代码,自己定义_rebuild_tensor_v2函数:

import torch._utils
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

扩展阅读
这种解决版本不兼容冲突的方法也叫作:Monkey Patch。猴子补丁。
大概意思就是:在运行时对已有代码进行修改,达到hot patch的功能。
具体可参考:https://blog.csdn.net/fly910905/article/details/77152110

你可能感兴趣的:(Pytorch学习)