使用多 GPU 训练保存模型权重后,再次加载 state_dict 会出现 ‘“Missing key(s)” 错误,信息如下,可以发现预期的权重 key 比文件中保存的 key 少了 'module.' 。或者说,在多 GPU 训练的情况下,通过 torch.save() 保存的模型权重的 key 多了 'module.'。
RuntimeError: Error(s) in loading state_dict for LeNet:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "classifier.0.weight", "classifier.0.bias", "classifier.2.weight", "classifier.2.bias".
Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.classifier.0.weight", "module.classifier.0.bias", "module.classifier.2.weight", "module.classifier.2.bias".
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 24763 closing signal SIGTERM
具体原因未知,官方文档也未给出更详细的例子,只有一般用法,如下。
import torch
import torchvision.models as models
# PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
# To load model weights, you need to create an instance of the same model first, and then load the parameters using load_state_dict() method.
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
可能是因为执行多 GPU 训练时,使用官方推荐的 python -m torch.distributed.lauch 或 torchrun 工具所致。
针对以上问题有两种解决方法,可以分为加载权重后和保存权重前。
(1)加载权重后修改 key
首先通过 torch.load() 加载权重文件,然后遍历字典,如果 key 中包含 'module' 则将其删掉,参考这里。
weights_name = 'weights-ep10-1641471178.0502117.rank-0.pth'
weights = torch.load(weights_name)
weights_dict = {}
for k, v in weights.items():
new_k = k.replace('module.', '') if 'module' in k else k
weights_dict[new_k] = v
model.load_state_dict(weights_dict)
(2)保存权重前增加 module
使用 torch.save() 保存权重时,通过 model.module.state_dict() 获取模型权重,而不是像官方示例中只用 model.state_dict() ,参考这里。
model_weights_name = "weights.pth"
torch.save(model.module.state_dict(), model_weights_name)
注意在多 GPU 训练情况下才会出现保存模型权重的 key 多了 'module.',以上两种方法选择其中一种即可,例如,当已经拿到多 GPU 训练的模型时,使用方法(1)比较好;如果重新训练模型,则可以直接使用方法(2)。