多GPU下训练时保存模型

在pytorch中,使用多GPU训练时需要用到 【nn.DataParallel】

在多GPU训练时,模型会被DataParallel进行封装,训练时保存的模型会多出来一个module,所以当预测模型时不需要多GPU,直接用单GPU加载参数时会报错。
因此,可以在训练时将保存模型的代码改为:`

if len(gpu_ids) > 1:
  t.save(model.module.state_dict(), "model.pth")
else:
  t.save(model.state_dict(), "model.pth")

参考:https://www.jb51.net/article/189297.htm

你可能感兴趣的:(深度学习,python,人工智能)