本系列是博主刚开始接触深度学习时写的一些笔记,写的很早了一直没有上传,趁着假期上传一下,作为分享,希望能帮助到你。
目录
一、模型保存
二、模型加载
1.加载模型
2.加载模型参数
总结
保存模型/模型参数。
torch.save(obj, f, pickle_module = ,pickle_protocol=2)
其中,obj是需要保存的对象,f是类文件对象或一个保存文件名的字符串,pickle_module指用于picking元数据和对象的模块,pickle_protocol指可以覆盖的默认参数。举例说明:
torch.save(model, ‘model.pt’) #保存整个模型
Torch.save(model.state_dict(), ‘model.pt’) #保存训练好的网络权重
torch.load(f, map_location=None,pickle_module = )
其中,f是类文件对象或一个保存文件名的字符串,map_location指一个函数或字典规定如何映射存储设备,pickle_module指用于unpicking元数据和对象的模块(必须匹配序列化文件时的pickle_module)。
torch.nn.Module.load_state_dict(state_dict, strict=True)
其中,state_dict指保存parameters和persistent buffers的字典。只有包含了可学习参数的层(如卷积层、线性层等)和已注册的命令才有模型的state_dict入口。
举例说明:
#(1)
#保存整个模型
torch.save(model_object, ‘model.pth’)
#加载模型
model = torch.load(‘model.pth’)
#(2)
#保存参数
torch.save(model_object.state_dict(), ‘params.pth’)
#加载模型
model_object=model()
model_object.load_state_dict(torch.load(‘params.pth))
#(2)的模型效果非常差,解决方法:
#(2)plus:
#保存参数
torch.save(model_object.state_dict(), ‘params.pth’)
#加载模型
model_object=model()
model_object.load_state_dict(torch.load(‘params.pth))
model.eval() #固定dropout()和BN层
其中,model.eval()的作用是固定dropout()和BN层。
以上就是今天要讲的内容,本文介绍了模型保存与加载的详细代码实现,希望能够帮助到你。如有错误,请及时指出,我们一起进步!