torch0.4.0版本加载torch0.4.1版本训练的pytorch模型

利用torch0.4.0版本的pytorch加载在torch0.4.1版本上训练的pytorch模型时,会遇到如下报错:

RuntimeError: Error(s) in loading state_dict for DataParallel:
				Unexpected key(s) in state_dict: "xxxxx.num_batches_tracked"

如图所示:torch0.4.0版本加载torch0.4.1版本训练的pytorch模型_第1张图片
此时只需要在加载模型的时候修改一下keys的名称即可,如下所示:

state_dict = torch.load(args.model)
model.load_state_dict({k.replace('.num_batches_tracked', ''):v for k,v in state_dict['state_dict'].items()})

你可能感兴趣的:(Pytorch)