PyTorch多GPU模型保存和加载的一个注意事项-Unexpected key(s) in state_dict

用PyTorch加载已经保存好的模型参数文件时遇到一个bug:

Unexpected key(s) in state_dict: “module.features. …”.,Expected “.features…”

意思是从本地文件中加载模型的state_dict时,state_dict的key值不匹配。查了一些资料后,发现是PyTorch多gpu保存的问题,导致保存下来的state_dict中的key比原来都多了一个module,因此出现了上述这个问题。下面简单验证一下。

import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        
        self.conv = nn.Conv2d(1, 3, 3, 1, 1, bias=False)
        self.model = nn.Sequential(
        	nn.Conv2d(3, 1, 1, 1, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )
        
     def forward(self, x):
		self.model(self.conv(x))

先查看直接定义后的结果:

model = net()
print(model)

上述代码的结果为:

net(
	(conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (model): Sequential(
    	(0): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(1, eps=1e-5, momentum=0.1, affine=True, tracking_running_stats=True)
        (2): ReLU(inplace)
    )
)

再分别查看将模型并行和转移到GPU上的结果:

先单独查看将模型转移到GPU上的结果:

if torch.cuda.is_available():
	model = model.cuda(0)
	print(model)

上述代码的结果和之前的结果相同,在此就不贴结果了。说明将模型转移到GPU上的操作是没有影响的。

再来看模型并行的结果:

if torch.cuda.device_count() > 1:
	model = nn.DataParallel(model)
    print(model)
    model = model.cuda()
    print(model)

上述代码两个print的结果相同,均为:

DataParallel(
	(module):net(
    	(conv): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    	(model): Sequential(
    		(0): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
        	(1): BatchNorm2d(1, eps=1e-5, momentum=0.1, affine=True, tracking_running_stats=True)
        	(2): ReLU(inplace)
   		 )
    )
)

由上述代码可知,模型并行使原模型外被包裹了一层module,会影响到模型的保存和加载:

先保存并行的模型:

model = net()
if torch.cuda.device_count() > 1:
	model = nn.DataParallel(model).cuda()
    torch.save(model.state_dict(), 'demo.pth')

若直接加载:

model2 = net()
model2.load_state_dict(torch.load('demo.pth'))

此时就会遇到文章开头说的错误。解决这个问题有两种方法:

  1. 先将模型并行,再加载权重:

    model2 = net()
    model2 = nn.DataParallel(model2).cuda()
    model2.load_state_dict(torch.load('demo.pth'))
    
  2. 去掉本地权重文件key值中的module:

    model2 = net()
    model2.load_state_dict({k.replace('module.', ''):v for k, v in torch.load('demo.pth').items()})
    model2 = nn.DataParallel(model2).cuda()
    

因为模型始终是要并行的,所以推荐方法一,不增加额外的代码 ,注意一下模型并行和加载权重文件的顺序即可。

你可能感兴趣的:(PyTorch)