Pytorch 单机加载分布式训练的模型

前言

在使用load_state_dict函数来加载模型的时候会出现各种各样的坑。比如会报这种错误

RuntimeError: Error(s) in loading state_dict for Net:
Missing key(s) in state_dict:

网上有人说,在load_state_dict函数中,设置strict为True,即load_state_dict(xxxx,True),但是这种方式会有Bug。在你加载分布式训练的模型的时候,参数会无效……

所有又有的人说
在load_state_dict(torch.load(‘net.pth’)前,增加model = nn.DataParallel(model)

但是如果我自己的机器并不支持分布式训练呢??会报这样的错

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

所以如果是单纯的想加载分布式训练保存的模型呢????

解决方法

https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-torch-nn-dataparallel-models

保存模型的时候不要用 torch.save(net.state_dict(), path

而是用

torch.save(net.module.state_dict(), path)

加一个module就好了。

如果是别人保存的模型,那就让他保存module的……

你可能感兴趣的:(【——机器学习相关——】)