pytorch保存和加载模型

之前想学习保存和加载模型的代码,在知乎上看到一个回答,发现两行代码就可以搞定,于是兴冲冲的加上了:

torch.save(model, "model.pth.tar") 
model_dict=torch.load("model.pth.tar")

然后就大胆的去训练了,结果训练结束,准备load时,发现load得到的结果,就只有模型的结构,参数完全没保存下来…
(哎,当时看到答主说这种方式是保存了整个网络,就以为整个网络必然包括参数啊,谁知道仅仅是结构)

于是换了一种方式:

checkpoint = {
    "model_struct": model,
    "model_param": model.state_dict(),
    "model_cfg": config}
torch.save(checkpoint, “model.ckpt")

这种方式是自己建立一个字典checkpoint,然后分别保存模型结构model_struct、模型参数model_param和相关配置model_cfg,然后保存为.ckpt文件(至于为何.pth.tar.ckpt到底有什么不一样,暂时还不清楚)

这次的教训就是,单保存模型是不能保存网络参数的,需要调用模型的.state_dict()属性将参数拿出来

参考文章:
https://zhuanlan.zhihu.com/p/38056115

你可能感兴趣的:(pytorch,深度学习,python)