PyTorch: Loading a GPU-Trained Model on the CPU

Loading directly on the GPU:

 import torch
 pretrain = torch.load(opt.pretrain_path)  # checkpoint dict saved during GPU training
 model.load_state_dict(pretrain['state_dict'])

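Loading a GPU model on the CPU: pass map_location to torch.load so every tensor is moved into CPU memory. If the model was trained wrapped in nn.DataParallel, its parameter names also carry a `module.` prefix that must be stripped before load_state_dict will accept them: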

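 import torch
 from collections import OrderedDict
 # map_location remaps every storage that was saved on a GPU to CPU memory
 pretrain = torch.load(opt.pretrain_path, map_location=lambda storage, loc: storage)
 new_state_dict = OrderedDict()
 for k, v in pretrain.items():
     if k == 'state_dict':
         state_dict = OrderedDict()
         for key in v:
             # remove the `module.` prefix added by nn.DataParallel; keys without
             # the prefix are kept as-is instead of losing their first 7 characters
             name = key[len('module.'):] if key.startswith('module.') else key
             state_dict[name] = v[key]
         new_state_dict[k] = state_dict
     else:
         new_state_dict[k] = v  # copy other checkpoint entries unchanged
 model.load_state_dict(new_state_dict['state_dict'])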
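The key-renaming loop above can be pulled out into a small reusable function. This is a sketch, not the original author's code: `strip_prefix` and `fix_checkpoint` are hypothetical names, and the checkpoint layout (a dict whose `'state_dict'` entry holds the parameters) is assumed from the snippet above. Unlike slicing off a fixed 7 characters, it strips the prefix only when a key actually starts with `module.`, so it is safe on checkpoints saved with or without nn.DataParallel.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are copied unchanged; key order is preserved.
    """
    out: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        out[new_key] = value
    return out


def fix_checkpoint(checkpoint: Mapping[str, Any], key: str = "state_dict") -> dict:
    """Hypothetical helper: shallow-copy a checkpoint, renaming keys only under `key`.

    Other entries (e.g. an epoch counter or optimizer state) are copied as-is,
    and the input checkpoint is not modified.
    """
    fixed = dict(checkpoint)
    fixed[key] = strip_prefix(checkpoint[key])
    return fixed


# Usage with plain values standing in for tensors:
ckpt = {"epoch": 10, "state_dict": OrderedDict([("module.fc.weight", 1), ("module.fc.bias", 2)])}
fixed = fix_checkpoint(ckpt)
print(list(fixed["state_dict"]))  # ['fc.weight', 'fc.bias']
```

In real use, the result of torch.load(..., map_location=...) would be passed to fix_checkpoint and the returned `'state_dict'` entry handed to model.load_state_dict. The helper deliberately avoids importing torch: it only renames keys, so the values can be tensors or anything else.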
