pytorch 分布式训练GPU模型转CPU

       近期在公司实习遇到一个问题,训练时,采用的是分布式的GPU训练的模型,上线需要cpu版本的,因此测试时,模型载入出错,需要转成CPU版。转换方法如下:

model = torch.load(model_path)
d = collections.OrderedDict()
for key, value in model.state_dict().items():
    tmp = key[7:]
    d[tmp] = value

model.load_state_dict(d)

分布式GPU训练的模型的key值会多了module.,将它去掉,重新载入模型即可。

你可能感兴趣的:(pytorch)