[深度学习笔记(3)]模型保存与加载


本系列是博主刚开始接触深度学习时写的一些笔记,写的很早了一直没有上传,趁着假期上传一下,作为分享,希望能帮助到你。

目录

一、模型保存

二、模型加载

1.加载模型

2.加载模型参数

总结


一、模型保存

保存模型/模型参数。

torch.save(obj, f, pickle_module = ,pickle_protocol=2)

        其中,obj是需要保存的对象,f是类文件对象或一个保存文件名的字符串,pickle_module指用于picking元数据和对象的模块,pickle_protocol指可以覆盖的默认参数。举例说明:

torch.save(model, ‘model.pt’)  #保存整个模型
Torch.save(model.state_dict(), ‘model.pt’)  #保存训练好的网络权重

二、模型加载

1.加载模型

torch.load(f, map_location=None,pickle_module = )

        其中,f是类文件对象或一个保存文件名的字符串,map_location指一个函数或字典规定如何映射存储设备,pickle_module指用于unpicking元数据和对象的模块(必须匹配序列化文件时的pickle_module)。

2.加载模型参数

torch.nn.Module.load_state_dict(state_dict, strict=True)

        其中,state_dict指保存parameters和persistent buffers的字典。只有包含了可学习参数的层(如卷积层、线性层等)和已注册的命令才有模型的state_dict入口。

举例说明:

#(1)
#保存整个模型
torch.save(model_object, ‘model.pth’)
#加载模型
model = torch.load(‘model.pth’)

#(2)
#保存参数
torch.save(model_object.state_dict(), ‘params.pth’)
#加载模型
model_object=model()
model_object.load_state_dict(torch.load(‘params.pth))

#(2)的模型效果非常差,解决方法:
#(2)plus:
#保存参数
torch.save(model_object.state_dict(), ‘params.pth’)
#加载模型
model_object=model()
model_object.load_state_dict(torch.load(‘params.pth))
model.eval() #固定dropout()和BN层

        其中,model.eval()的作用是固定dropout()和BN层。


总结

        以上就是今天要讲的内容,本文介绍了模型保存与加载的详细代码实现,希望能够帮助到你。如有错误,请及时指出,我们一起进步!

你可能感兴趣的:(深度学习,关于pytorch的tip,深度学习,pytorch,人工智能,python)