Migrating from 0.3.1 to 0.4.1 or later
When loading such a model directly in code, you often hit the error 'BatchNorm2d' object has no attribute 'track_running_stats'. This happens because the BN operation in 0.3.1 has no track_running_stats parameter; BatchNorm in 0.3.1 is defined as follows:
class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)
$$ y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta $$
Parameters:
num_features – num_features from an expected input of size batch_size x num_features x height x width
eps – a value added to the denominator for numerical stability. Default: 1e-5
momentum – the value used for the running_mean and running_var computation. Default: 0.1
affine – a boolean value that when set to True, gives the layer learnable affine parameters. Default: True
Shape:
Input: (N,C,H,W)
Output: (N,C,H,W) (same shape as input)
In 0.4.1 the definition changed:
class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
If track_running_stats is set to False, this layer then does not keep running estimates,
and batch statistics are instead used during evaluation time as well.
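As a quick illustration of that flag (a minimal sketch, not taken from the original post): with track_running_stats=False the layer keeps no running_mean/running_var buffers and normalizes with batch statistics even in eval mode.
import torch

bn = torch.nn.BatchNorm2d(16, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # both None: no running estimates are kept

bn_default = torch.nn.BatchNorm2d(16)   # track_running_stats=True by default
print(bn_default.running_mean.shape)    # buffer of size (16,), updated during training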
So when loading a 0.3.1 model with 0.4.1 or later, the track_running_stats attribute has to be added to the model's BN layers. The code is as follows:
import torch

def recursion_change_bn(module):
    # Recursively add the attribute that 0.3.1 checkpoints are missing
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = True
    else:
        for name, module1 in module._modules.items():
            module1 = recursion_change_bn(module1)
    return module

check_point = torch.load(check_point_file_path)
model = check_point['net']
recursion_change_bn(model)
model.eval()
Alternatively, the model can be patched right where it is loaded. The model's state_dict is essentially an OrderedDict, so it is enough to add the entries for num_batches_tracked: after each key ending in running_var, insert a corresponding num_batches_tracked key whose value is a zero tensor. The code is as follows:
from collections import OrderedDict
import torch

checkpoint = torch.load(checkpoint_path, map_location=device)
mapped_state_dict = OrderedDict()
for key, value in checkpoint['state_dict'].items():
    print(key)
    mapped_key = key
    mapped_state_dict[mapped_key] = value
    # Insert a num_batches_tracked entry right after each running_var entry
    if 'running_var' in key:
        mapped_state_dict[key.replace('running_var', 'num_batches_tracked')] = torch.zeros(1).to(device)
model.load_state_dict(mapped_state_dict)
Loading a model saved by 0.4.1 or later in 0.3.1
The relevant API changes in 0.4 (illustrated in the sketch below) are:
In 0.4, device placement is done with .to(device).
0.4 removed Variable; plain tensors are enough.
with torch.no_grad(): replaces volatile; volatile is deprecated, so when no gradients are needed during testing, wrap the code in with torch.no_grad():.
.data is replaced by .detach(): x.detach() returns a Tensor that shares x's data and has requires_grad=False, and if x is needed in the backward pass, in-place changes to the tensor returned by x.detach() are still caught by autograd. By contrast, changes to x.data are not tracked by autograd, so if the backward pass needs x, the gradients will silently be wrong.
Some interfaces changed in PyTorch 0.4, and saved models are backward compatible but not forward compatible: a newer PyTorch can load models saved by an older one, but not the other way around.
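A minimal sketch of these 0.4-style idioms; the layer, tensor shapes, and variable names are placeholders, not from the original post.
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(4, 2).to(device)   # .to(device) handles device placement
x = torch.randn(8, 4, device=device)       # a plain tensor, no Variable wrapper

with torch.no_grad():                      # replaces volatile=True at test time
    out = model(x)

y = model(x)
feat = y.detach()                          # shares data with y, requires_grad=False;
                                           # in-place edits to feat are flagged by autograd if y is needed for backward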
When loading a model saved by PyTorch 0.4 with PyTorch 0.3, add the following code block before the load. It fixes the error (AttributeError: Can't get attribute '_rebuild_tensor_v2' on \lib\site-packages\torch\_utils.py'>); compare with the _utils.py file for details:
# See https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560
import torch

# *********** Add this code block when loading 0.4.1+ models with PyTorch 0.3.1 ***********
# Use the following functions in place of the ones in torch._utils (they may be missing
# in 0.3.1 or have a different interface, which is what causes the error above).
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

try:
    torch._utils._rebuild_parameter
except AttributeError:
    def _rebuild_parameter(data, requires_grad, backward_hooks):
        param = torch.nn.Parameter(data, requires_grad)
        # NB: This line exists only for backwards compatibility; the
        # general expectation is that backward_hooks is an empty
        # OrderedDict. See Note [Don't serialize hooks]
        param._backward_hooks = backward_hooks
        return param
    torch._utils._rebuild_parameter = _rebuild_parameter
# ***********************************************************************
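With the shim in place, the 0.4.1 checkpoint can then be deserialized in 0.3.1 as usual; a minimal sketch, assuming the checkpoint stores the full model under a 'net' key as in the earlier snippet:
checkpoint = torch.load(checkpoint_path)  # no longer raises the _rebuild_tensor_v2 error
model = checkpoint['net']                 # assumption: model saved under the 'net' key
model.eval()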
When exporting to ONNX you may also hit the opposite problem, a leftover num_batches_tracked entry, reported as KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'. The handling mirrors the insertion of num_batches_tracked above: simply delete that key. The code is as follows:
from collections import OrderedDict
import torch

checkpoint = torch.load(checkpoint_path, map_location=device)
mapped_state_dict = OrderedDict()
for key, value in checkpoint['state_dict'].items():
    print(key)
    mapped_key = key
    mapped_state_dict[mapped_key] = value
    # Drop the extra key that 0.3.1 does not recognize
    if 'num_batches_tracked' in key:
        del mapped_state_dict[key]
model.load_state_dict(mapped_state_dict)
When exporting a 0.4.1 model to an ONNX model for 0.3.1, both of the code segments above are needed.
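For reference, a minimal sketch of the ONNX export call itself; the input shape and output file name are placeholders, not from the original post:
import torch

model.eval()
dummy_input = torch.randn(1, 3, 224, 224).to(device)  # placeholder input shape
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)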
Exporting a 1.0.0 model
PyTorch 1.0.0 adds torch.jit, which packs the network definition and the weights directly into the model file, so the network definition no longer has to be imported when the model file is used; this makes the model much more convenient to use.
jit export of the model:
import torch

def pth_to_jit(model, save_path, device="cuda:0"):
    model = model.to(device)
    model.eval()
    input_x = torch.randn(1, 3, 144, 144).to(device)  # example input size
    new_model = torch.jit.trace(model, input_x)        # trace records the graph for this input
    torch.jit.save(new_model, save_path)
Loading and using the jit model:
def load_jit(jit_model_path):
    model = torch.jit.load(jit_model_path, map_location=torch.device('cuda:0'))
    model.eval()
    return model
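A minimal usage sketch combining the two helpers above; the file path and input shape are placeholders:
import torch

jit_path = "model_jit.pth"                    # placeholder output path
pth_to_jit(model, jit_path, device="cuda:0")  # trace the model and save a self-contained file

jit_model = load_jit(jit_path)                # the network definition is not needed here
x = torch.randn(1, 3, 144, 144).to("cuda:0")
with torch.no_grad():
    out = jit_model(x)
print(out.shape)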