'BatchNorm2d' object has no attribute 'track_running_stats'

 'BatchNorm2d' object has no attribute 'track_running_stats'

还不知道什么原因:

后来发现是老版本训练的权重,用0.4.0以后版本打开,这个变量没有,不能兼容。

我的解决方法:

load模型后,报存pth,再新建网络,加载模型,理论上能用了。

import pickle
pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
self.net = torch.load(graph_path, map_location=lambda storage, loc: storage, pickle_module=pickle)

try:
    torch.save(self.net.state_dict(), '111.pth')
except Exception as e:
    print(e)
    torch.save(self.net.module.features.state_dict(), '111.pth')
# self.net = torch.load(graph_path, map_location=lambda storage, loc: storage)
self.net.eval()

 

用pytorch加载训练好的模型的时候遇到了如下的问题:

AttributeError: 'module' object has no attribute '_rebuild_tensor_v2'

到网上查了一下是由于训练模型时使用的是新版本的pytorch,而加载时使用的是旧版本的pytorch。

解决方法:

1、既然是pytorch版本较老,那最简单的解决方法当然是简单的升级一下pytorch就ok了。

这个没试过:

2、国外的大神给了另一种解决方法,就是在程序开头添加下面的代码,即可以使老版本pytorch兼容新版本pytorch,参考链接

https://discuss.pytorch.org/t/batchnorm2d-object-has-no-attribute-track-running-stats/17525/15

这个貌似管用:

 def recursion_change_bn(module):
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = 1
    else:
        for i, (name, module1) in enumerate(module._modules.items()):
            module1 = recursion_change_bn(module1)
    return module
and
use it when you load model

check_point = torch.load(check_point_file_path)
model = check_point['net']
for i, (name, module) in enumerate(model._modules.items()):
    module = recursion_change_bn(model)
model.eval()

 

你可能感兴趣的:(torch)