PyTorch 多 GPU 训练保存模型权重 key 多了 ‘module.‘

一、问题表现

使用多 GPU 训练保存模型权重后,再次加载 state_dict 会出现 ‘“Missing key(s)” 错误,信息如下,可以发现预期的权重 key 比文件中保存的 key 少了 'module.' 。或者说,在多 GPU 训练的情况下,通过 torch.save() 保存的模型权重的 key 多了 'module.'。

RuntimeError: Error(s) in loading state_dict for LeNet:
	Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "classifier.0.weight", "classifier.0.bias", "classifier.2.weight", "classifier.2.bias". 
	Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.classifier.0.weight", "module.classifier.0.bias", "module.classifier.2.weight", "module.classifier.2.bias". 
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 24763 closing signal SIGTERM

具体原因未知,官方文档也未给出更详细的例子,只有一般用法,如下。

import torch
import torchvision.models as models

# PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

# To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

可能是因为执行多 GPU 训练时,使用官方推荐的 python -m torch.distributed.lauch 或 torchrun 工具所致。

二、解决方法

针对以上问题有两种解决方法,可以分为加载权重后保存权重前

(1)加载权重后修改 key

首先通过 torch.load() 加载权重文件,然后遍历字典,如果 key 中包含 'module' 则将其删掉,参考这里。

weights_name = 'weights-ep10-1641471178.0502117.rank-0.pth'  
weights = torch.load(weights_name)

weights_dict = {}
for k, v in weights.items():
    new_k = k.replace('module.', '') if 'module' in k else k
    weights_dict[new_k] = v

model.load_state_dict(weights_dict)

(2)保存权重前增加 module

使用 torch.save() 保存权重时,通过 model.module.state_dict() 获取模型权重,而不是像官方示例中只用 model.state_dict() ,参考这里。

model_weights_name = "weights.pth"
torch.save(model.module.state_dict(), model_weights_name)

注意在多 GPU 训练情况下才会出现保存模型权重的 key 多了 'module.',以上两种方法选择其中一种即可,例如,当已经拿到多 GPU 训练的模型时,使用方法(1)比较好;如果重新训练模型,则可以直接使用方法(2)。

你可能感兴趣的:(PyTorch,pytorch,深度学习,人工智能)