RuntimeError: Error(s) in loading state_dict for DataParallel:

  • 错误原因是在train使用了单GPU,但在test里面使用多GPU。
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.encoder_stage1.0.weight".
    Unexpected key(s) in state_dict: "encoder_stage1.0.weight". 
  • code里这句话就是使用多GPU的意思
model = torch.nn.DataParallel(model, device_ids=args.gpu_id)

解决方法:

在前面添加上‘module.’

ckpt = checkpoint['net']
    new_ckpt = {}
    for k, v in ckpt.items():
        k = 'module.' + k
        new_ckpt[k]=v
    model.load_state_dict(new_ckpt)

你可能感兴趣的:(Errors,pytorch,深度学习,python)