pytorch加载模型错误 Missing key(s) RuntimeError: Error(s) in loading state_dict for 多卡加载错误

  • torch.load() 报错 Missing key(s) pytorch

错误情况: 在加载预训练模型时出错

RuntimeError: Error(s) in loading state_dict for :

Missing key(s) in state_dict: “features.0.weight” …

Unexpected key(s) in state_dict: “module.features.0.weight” …

错误原因:

使用nn.DataParallel包装后的模型参数的关键字会比没用 nn.DataParallel 包装的模型参数的关键字前面多一个"module."

解决方法:

  1. 使用 net 加载 nn.DataParallel(net) 训练出来的模型:

    1. 把 module. 删掉

      代码出处

      # original saved file with DataParallel
      state_dict = torch.load('model_path')
      # create new OrderedDict that does not contain `module.`
      from collections import OrderedDict
      new_state_dict = OrderedDict()
      for k, v in state_dict.items():
          name = k[7:] # remove `module.`
          new_state_dict[name] = v
      # load params
      net.load_state_dict(new_state_dict)
      

      代码出处

      checkpoint = torch.load('model_path')
      for key in list(checkpoint.keys()):
          if 'model.' in key:
              checkpoint[key.replace('model.', '')] = checkpoint[key]
              del checkpoint[key]
      
      net.load_state_dict(checkpoint)
      
    2. 加载模型时使用 nn.DataParallel

      checkpoint = torch.load('model_path')
      net = torch.nn.DataParallel(net)
      net.load_state_dict(checkpoint)
      
  2. 使用 nn.DataParallel(net) 加载 net 训练出的模型:

    • 保存权重前增加 module

      使用 torch.save() 保存权重时,通过 model.module.state_dict() 获取模型权重

      torch.save(net.module.state_dict(), 'model_path')
      
    • 在使用nn.DataParallel之前就先读取模型,然后再使用nn.DataParallel

      net.load_state_dict(torch.load('model_path'))
      net = nn.DataParallel(net, device_ids=[0, 1]) 
      
    • 手动添加 module.

      net = nn.DataParallel(net) 
      from collections import OrderedDict
      new_state_dict = OrderedDict()
      state_dict =savepath #预训练模型路径
      for k, v in state_dict.items():
      	# 手动添加“module.”
          if 'module' not in k:
              k = 'module.'+k
          else:
          # 调换module和features的位置
              k = k.replace('features.module.', 'module.features.')
          new_state_dict[k]=v
      
      net.load_state_dict(new_state_dict)
      
      

你可能感兴趣的:(pytorch笔记)