【pytorch】pytorch模型保存技巧

Pytorch会把模型相关信息保存为一个字典结构的数据,以用于继续训练或者推理。

1 保存与加载模型参数

        这是最常见的模型保存与加载方式,保存方式如下:

state = model.state_dict()
torch.save(state, ‘xxx.pth’)

        模型参数加载之前需要先定义模型的网络结构,假设已定义好的网络结构为model。那么模型参数加载方式如下:

checkpoint = torch.load('xxx.pth', map_location='cpu')
model.load_state_dict(checkpoint)
model.cuda()
model.train()/model.eval()

2 保存与加载训练参数

        除了模型参数之外,torch还可以保存其他训练相关参数,例如学习率、优化器信息等,保存方式如下:

state = {'checkpoint':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch_num}
torch.save(state, ‘xxx.pth’)

        模型参数加载之前需要先定义模型的网络结构,假设已定义好的网络结构为model。那么模型参数加载方式如下:

checkpoint = torch.load('xxx.pth', map_location='cpu')
model.load_state_dict(checkpoint[‘checkpoint’])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
model.cuda()
model.train()/model.eval()

3 保存模型结构与参数

        上述方法在加载前必须先定义模型的网络结构,为了不需要单独加载模型的网络结构可以把整个模型保存起来。保存方式如下:

torch.save(model, ‘xxx.pth’)

        模型加载时不需要先定义网络结构,直接加载整个网络即可,加载方式如下:

model = torch.load('xxx.pth', map_location='cpu')
model.cuda()
model.train()/model.eval()

        同样地,也可以类似2中做法把其他参数保存进来。

4 模型重新保存

        保存模型通常是为了继续训练,或者是为了算法推理。特别是在算法推理过程中,损失函数相关定义是不需要的。特别是一些特别的损失函数,需要依赖编译的环境,但这在模型推理时又不需要。那么,采用3中的保存方式通常会提醒模型模块的缺失。做法是:

1、先在模型网路定义程序中注释掉不需要的模块;
2、定义模型网络结构model,用1或2中的方式加载模型;
3、用3中的方式再次保存加载后的模型,这样之后就可以重复用3中的方式加载保存后的模型了。

python三维点云从基础到深度学习_Coding的叶子的博客-CSDN博客_3d点云 python从三维基础知识到深度学习,将按照以下目录持续进行更新。更新完成的部分可以在三维点云专栏中查看。https://blog.csdn.net/suiyingy/category_11740467.htmlhttps://blog.csdn.net/suiyingy/category_11740467.html1、点云格式介绍(已完成)常见点云存储方式有pcd、ply、bin、txt文件。open3d读写pcd和plhttps://blog.csdn.net/suiyingy/article/details/124017716

更多三维、二维感知算法和金融量化分析算法请关注“乐乐感知学堂”微信公众号,并将持续进行更新。

你可能感兴趣的:(Pytorch,深度学习环境,python,pytorch,深度学习,神经网络,python)