RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 2.

解决报错RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 2. Please use torch.load with map_location to map your storages to an existing device.

这个报错的原因是训练的时候用了编号为2的GPU,但是在你用pytorch load这个模型运行的时候使用的不是2号GPU。

解决方法很简单,在load的时候加一个map_location:

#model.load_state_dict(torch.load(model_path))  #会报错的写法
model.load_state_dict(torch.load(model_path,map_location={'cuda:2': 'cuda:0','cuda:1': 'cuda:0'}))  #将对应的gpu编号映射到正确的gpu上

你可能感兴趣的:(pytorch,pytorch,人工智能,python)