pytorch 换版本_Pytorch 模型版本切换

0.3.1转到0.4.1或更高版本

直接使用代码导入时常碰到 ‘BatchNorm2d’ object has no attribute ‘track_running_stats’的报错信息,这是由于0.3.1中的BN操作中没有配置track_running_stats参数,0.3.1中BatchNorm的定义如下

class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)

$$ y = frac{x - mean[x]}{ sqrt{Var[x] + epsilon}} * gamma + beta$$Parameters:

num_features – num_features from an expected input of size batch_size x num_features x height x width

eps – a value added to the denominator for numerical stability. Default: 1e-5

momentum – the value used for the running_mean and running_var computation. Default: 0.1

affine – a boolean value that when set to True, gives the layer learnable affine parameters. Default: TrueShape:

Input: (N,C,H,W)

Output: (N,C,H,W) (same shape as input)

而在0.4.1中定义发生了变化

class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

If track_running_stats is set to False, this layer then does not keep running estimates,

and batch statistics are instead used during evaluation time as well.

所以使用0.4.1或以上版本导入0.3.1模型时需要对模型中的BN层添加track_running_stats参数,代码如下1

2

3

4

5

6

7

8

9

10

11

12def (module):

if isinstance(module, torch.nn.BatchNorm2d):

module.track_running_stats = True

else:

for name, module1 in module._modules.items():

module1 = recursion_change_bn(module1)

check_point = torch.load(check_point_file_path)

model = check_point['net']

for name, module in model._modules.items():

recursion_change_bn(model)

model.eval()

另外,也可以在导入模型处直接修改模型,模型的statedict本身可以理解为一个Orderdict,在模型中添加参数num_batches_tracked对应的值即可. 具体做法是在键值为running_var后添加一个键值为num_batches_tracked,值为0的Tensor. 具体代码如下1

2

3

4

5

6

7

8

9checkpoint = torch.load(checkpoint_path, map_location=device)

mapped_state_dict = OrderedDict()

for key, value in checkpoint['state_dict'].items():

print(key)

mapped_key = key

mapped_state_dict[mapped_key] = value

if 'running_var' in key:

mapped_state_dict[key.replace('running_var', 'num_batches_tracked')] = torch.zeros(1).to(device)

model.load_state_dict(mapped_state_dict)

0.3.1版本导入0.4.1以上版本模型0.4中使用设备:.to(device)

0.4中删除了Variable,直接tensor就可以

with torch.no_grad():的使用代替volatile;弃用volatile,测试中不需要计算梯度的话,用with torch.no_grad():

data改用.detach;x.detach()返回一个requires_grad=False的共享数据的Tensor,并且,如果反向传播中需要x,那么x.detach返回的Tensor的变动会被autograd追踪。相反,x.data()返回的Tensor,其变动不会被autograd追踪,如果反向传播需要用到x的话,值就不对了。

pytorch0.4有一些接口已经改变,且模型向下版本兼容,不向上兼容。

使用pytorch0.3导入pytorch0.4保存的模型时候在导入前添加如下代码段,解决的报错内容为(AttributeError: Can’t get attribute ‘_rebuild_tensor_v2’ on \lib\site-packages\torch\_utils.py'>),详情可对比查看_utils.py文件:1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28# See https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560

import torch

# ***********pytorch0.3.1导入0.4.1以上版本模型时加入以下代码块**********

# 使用以下函数代替torch._utils中的函数(0.3.1中可能不存在或者接口不同导致的报错)

try:

torch._utils._rebuild_tensor_v2

except AttributeError:

def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):

tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)

tensor.requires_grad = requires_grad

tensor._backward_hooks = backward_hooks

return tensor

torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

try:

torch._utils._rebuild_parameter

except AttributeError:

def _rebuild_parameter(data, requires_grad, backward_hooks):

param = torch.nn.Parameter(data, requires_grad)

# NB: This line exists only for backwards compatibility; the

# general expectation is that backward_hooks is an empty

# OrderedDict. See Note [Don't serialize hooks]

param._backward_hooks = backward_hooks

return param

torch._utils._rebuild_parameter = _rebuild_parameter

# ***********************************************************************

在导出为ONNX模型时还可能会报错存在多余的num_batches_tracked值, 错误代码为KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict', 此处的处理方式和上边的添加num_batches_tracked键值对应,删除该键值即可,具体代码如下1

2

3

4

5

6

7

8

9checkpoint = torch.load(checkpoint_path, map_location=device)

mapped_state_dict = OrderedDict()

for key, value in checkpoint['state_dict'].items():

print(key)

mapped_key = key

mapped_state_dict[mapped_key] = value

if 'num_batches_tracked' in key:

del mapped_state_dict[key]

model.load_state_dict(mapped_state_dict)由0.4.1导出为0.3.1的ONNX模型时,上述两段代码都需要加入

导出为1.0.0模型

pytorch1.0.0添加了torch.jit, 可以直接将模型和网络打包到模型文件中,而不需要在使用模型文件时导入网络定义,在模型的使用时变得更加方便了

模型的jit导出1

2

3

4

5def pth_to_jit(model, save_path, device="cuda:0"):

model.eval()

input_x = torch.randn(1, 3, 144, 144).to(device) # 输入大小

new_model = torch.jit.trace(model, input_x)

torch.jit.save(new_model, save_path)

jit模型导入使用1

2

3def load_jit(jit_model_path):

model = torch.jit.load(jit_model_path, map_location=torch.device('cuda:0'))

model.eval()

你可能感兴趣的:(pytorch,换版本)