解决RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: “conv1.0...

项目场景:

在多GPU环境下用Pytorch训练的Resnet分类网络


问题描述

卷积神经网络ResNet训练好之后,测试环境或测试代码用了单GPU版或CPU版,在加载网络的时候报错,报错处代码为:

net.load_state_dict(torch.load(args.weights))

报错如下:

RuntimeError: Error(s) in loading state_dict for ResNet: 
	Missing key(s) in state_dict: "conv1.0.weights", "conv1.1.weights", "conv1.1.bias", ...

原因分析:

出现这种报错的原因主要是,state_dict加载模型权重时,参数不匹配。可能是PyTorch版本环境不一致、torch.nn.DataParallel()关键字不匹配、训练环境与测试环境GPU不同。

我遇见这种报错,一次是因为GPU进行训练,CPU进行测试;另一次是多GPU进行训练,测试时对GPU部分的处理,没有按照训练时做多GPU处理,是单GPU。


解决方案:

情况一:多GPU训练,测试时代码为单GPU
解决方案:在环境中定义GPU为多GPU,定义关键字torch.nn.DataParallel()

import torch.nn as nn
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"

if __name__ == '__main__':
	# 此处为测试代码对接收参数的初始化
	parser = argparse.ArgumentParser()
	parser.add_argument('-gpu', action='store_true', default=True, help='use gpu or not')
	......
	......
	......

	# 加载网络
	net = get_network(args)
	
	# 以下为多GPU环境增加代码,以解决本博客中提到的报错问题
	net = net.cuda()
	net = nn.DataParallel(net, device_ids=[0, 1, 2, 3])
	
	# 以下为原始测试代码部分
	net.load_dict(torch.load(args.weights))
	net.eval()
	......
	......
	......
		

情况二:GPU训练,测试时代码为CPU
解决方案:在加载模型权重时,设置 map_location='cpu'

if __name__ == '__main__':
	# 此处为测试代码对接收参数的初始化
	parser = argparse.ArgumentParser()
	parser.add_argument('-gpu', action='store_true', default=True, help='use gpu or not')
	......
	......
	......

	# 加载网络
	net = get_network(args)
	
	
	# 以下为原始测试代码部分
	# net.load_dict(torch.load(args.weights))
	# 改为cpu模式
	net.load_dict(torch.load(args.weights, map_location='cpu'))
	net.eval()
	......
	......
	......
		

情况二还有可能出现的报错提示如下:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

参考博文:

please use torch.load with map_location=torch.device(‘cpu‘) to map your storages to the CPU.

还有训练与测试参数不匹配的情况可以参考博文:

【错误记录】RuntimeError: Error(s) in loading state_dict for DataParallel: size mismatch for module

你可能感兴趣的:(pytorch,深度学习,神经网络,卷积神经网络,python)