pytorch 多gpu训练, 单cpu测试

虽然一般是继续使用gpu测试,但是可能还是会有这种需求的时候

 

# original saved file with DataParallel
checkpoint =  "checkpoints/{}_best_loss.pth.tar".format(model_name)
best_model = torch.load(checkpoint)
state_dict = best_model["state_dict"]
# create new OrderedDict that does not contain `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
	name = k[7:] # remove `module.`
	new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

 

你可能感兴趣的:(pytorch 多gpu训练, 单cpu测试)