pytorch多卡并行模型的保存与载入

pytorch多卡并行模型的保存与载入

当模型是在数据并行方式在多卡上进行训练的训练和保存,那么载入的时候也是一样需要是多卡。并且,load_state_dict()函数的调用要放在DataParallel()之后,而model.cuda()所在的位置无影响。

model = DefinedNetwork()
model = torch.nn.parallel.DataParallel(model, device_ids = [0,1])
model.load_state_dict(torch.load("model_best.pth"))
model.eval()
model.cuda()

你可能感兴趣的:(机器学习,python,pytorch,pytorch,多卡模型载入)