Fixing RuntimeError: Error(s) in loading state_dict for : Missing key(s) in state_dict

If a model was trained on multiple GPUs with torch.nn.DataParallel, it must also be wrapped in DataParallel before its saved weights are loaded.

I recently trained a model on multiple GPUs and saved it as a state_dict.
Loading it afterwards kept failing with this error:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BertMultiClassification:
Missing key(s) in state_dict
It means the loaded state dict lacks keys that the model expects.
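To make "missing keys" concrete: with its default strict=True, load_state_dict compares the model's own key set against the key set stored in the file. Keys the model expects but cannot find are reported as "Missing key(s)", and keys in the file that the model does not know are reported as "Unexpected key(s)". The sketch below mimics that comparison in plain Python; the helper name diff_state_dict_keys is hypothetical, not a PyTorch API. With real objects, calling model.load_state_dict(sd, strict=False) returns the same two lists as missing_keys and unexpected_keys.

```python
def diff_state_dict_keys(expected_keys, loaded_keys):
    """Compare the keys a model expects with the keys found in a checkpoint.

    Hypothetical helper. Returns (missing, unexpected), both sorted lists:
    - missing: expected by the model but absent from the checkpoint
    - unexpected: present in the checkpoint but unknown to the model
    """
    expected, loaded = set(expected_keys), set(loaded_keys)
    missing = sorted(expected - loaded)
    unexpected = sorted(loaded - expected)
    return missing, unexpected


# Key names abbreviated from the outputs shown in this post
model_keys = ["bert.embeddings.word_embeddings.weight",
              "bert.embeddings.position_embeddings.weight"]
ckpt_keys = ["module.bert.embeddings.word_embeddings.weight",
             "module.bert.embeddings.position_embeddings.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a pure prefix mismatch like the one in this post, every model key shows up as missing and every checkpoint key as unexpected. That pattern points to a naming difference rather than weights that are genuinely absent.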

First, print the key names the model expects:

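import torch
from transformers import BertConfig, BertModel, BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model_path)
bert_config = BertConfig.from_pretrained(pretrained_model_path)
bert_model = BertModel.from_pretrained(pretrained_model_path)
# BertMultiClassification and hpo_tree are the author's own classification head and label tree
bert_model = BertMultiClassification(bert_model, bert_config, len(hpo_tree.hpo2idx))
# state_dict() also lists buffers such as position_ids, which named_parameters() omits
model_param_list = list(bert_model.state_dict().keys())
print(model_param_list)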

Output:

"bert.embeddings.position_ids", "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.weight"

Then load the saved state_dict and print its keys:

load_dict = torch.load(bert_model_path)
print(load_dict.keys())

Output:

"module.bert.embeddings.position_ids", "module.bert.embeddings.word_embeddings.weight", "module.bert.embeddings.position_embeddings.weight", "module.bert.embeddings.token_type_embeddings.weight"

The difference is clear: every key saved from the model trained with DataParallel carries an extra "module." prefix.
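A minimal way to reproduce the prefix, assuming only PyTorch is installed: nn.DataParallel stores the wrapped network as its submodule named module, so every key in its state_dict gains "module.". The tiny nn.Linear below is a hypothetical stand-in for BertMultiClassification. On a machine without GPUs, DataParallel still wraps the model, so the key names come out the same. The third print also shows the save-side counterpart: wrapped.module.state_dict() yields unprefixed keys, which is a common way to save checkpoints that later load without DataParallel.

```python
import torch.nn as nn

# Hypothetical stand-in for the real model
model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)  # the original network becomes wrapped.module

print(list(model.state_dict().keys()))           # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']

# Loading the prefixed checkpoint into an unwrapped model reproduces the error
try:
    nn.Linear(4, 2).load_state_dict(wrapped.state_dict())
except RuntimeError as e:
    print(str(e).splitlines()[0])  # Error(s) in loading state_dict for Linear:
```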
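So after building the model, we also need to wrap it in DataParallel for multi-GPU execution before loading the weights. Changing the code to the following fixes the error: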

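# Restrict the visible GPUs before CUDA is first initialized (ideally at the very top of the script)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import torch
from transformers import BertConfig, BertModel, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model_path)
bert_config = BertConfig.from_pretrained(pretrained_model_path)
bert_model = BertModel.from_pretrained(pretrained_model_path)
bert_model = BertMultiClassification(bert_model, bert_config, len(hpo_tree.hpo2idx))
# Multi-GPU: wrapping in DataParallel adds the same "module." prefix the checkpoint has
# (if only one GPU is visible, the model stays unwrapped and the keys will still not match)
if torch.cuda.device_count() > 1:
    bert_model = torch.nn.DataParallel(bert_model, device_ids=[0, 1, 2, 3])
bert_model = bert_model.to(device)
# Load the state dict; the keys now match
bert_model.load_state_dict(torch.load(bert_model_path, map_location=device))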
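If you would rather load the checkpoint on a single GPU or the CPU without DataParallel, a common alternative to the fix above is to strip the "module." prefix from the saved keys. The following is a minimal sketch. strip_prefix is a hypothetical helper, and it works on any mapping, so the values can be tensors. Recent PyTorch versions also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which does the same thing in place; check that it exists in your version before relying on it.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key that starts with it.

    Hypothetical helper. Keys without the prefix are kept unchanged, and the
    original key order is preserved.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Usage with the post's variables (assumed to be defined as above, model NOT wrapped):
# state_dict = torch.load(bert_model_path, map_location="cpu")
# bert_model.load_state_dict(strip_prefix(state_dict))
```

Because unprefixed keys pass through untouched, applying the helper to a checkpoint that was saved without DataParallel is harmless.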
