执行model.load_state_dict报错map_location=torch.device(‘cpu‘)以及Unexpected key(s) in state_dict

本蒟蒻原先是在有GPU的服务器上训练得到了模型,现在想在没有GPU的服务器部署网页,需要加载该模型,当运行了下面代码时

model.load_state_dict(torch.load(output_dir))

报错信息如下

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

是因为没有将state_dict转移到CPU上,在PyTorch中,可以通过指定map_location参数为torch.device('cpu')来实现。

model.load_state_dict(torch.load(output_dir, map_location=torch.device('cpu')))

修改后再次运行到这一句,发现又有错误,报错信息如下:

model.load_state_dict(torch.load(output_dir, map_location=torch.device('cpu')))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Model:
        Unexpected key(s) in state_dict: "encoder.embeddings.position_ids". 

我打印出了两台服务器中state_dict的内容

for k, v in torch.load(output_dir, map_location=torch.device('cpu').items():
    print(k)

打印的内容都是一样的,都是下面的内容

encoder.embeddings.position_ids
encoder.embeddings.word_embeddings.weight
encoder.embeddings.position_embeddings.weight
encoder.embeddings.token_type_embeddings.weight
encoder.embeddings.LayerNorm.weight
encoder.embeddings.LayerNorm.bias
encoder.encoder.layer.0.attention.self.query.weight
encoder.encoder.layer.0.attention.self.query.bias
encoder.encoder.layer.0.attention.self.key.weight
encoder.encoder.layer.0.attention.self.key.bias
encoder.encoder.layer.0.attention.self.value.weight
encoder.encoder.layer.0.attention.self.value.bias
encoder.encoder.layer.0.attention.output.dense.weight
encoder.encoder.layer.0.attention.output.dense.bias
...
内容太多不展示了

仔细查看提示信息,发现是有一个未知的Key,我把它删掉竟然就跑通了

new_state_dict = torch.load(output_dir, map_location=torch.device('cpu'))
del new_state_dict['encoder.embeddings.position_ids']
model.load_state_dict(new_state_dict)

你可能感兴趣的:(人工智能后门和对抗攻击,深度学习,人工智能)