【pytorch教程】模型保存和加载

一、参考资料

[译]保存和加载模型
save and load PyTorch tensors

二、模型保存与加载

模型格式:pt,pth。这几种模型文件格式没有区别,只是后缀名不同而已。

方式一(推荐方式)

仅保存模型权重参数,不保存模型结构

# 保存
torch.save(model.state_dict(), 'params.pt')

# 加载
model = My_model(*args, **kwargs)  # model模型实例化,重构模型结构
model.load_state_dict(torch.load('params.pt'))  # 根据模型结构,加载模型参数
model.eval()  # 设置 dropout 和 batch normalization 层为评估模式。如果不这么做,可能导致模型推断结果不一致。

方式二

保存/载入整个pytorch模型

  • 以这种方式保存模型将使用Pythonpickle模块保存整个model的状态;

  • 保存/加载过程使用最直观的语法,涉及的代码量最少;

  • 缺点:(1)序列化数据绑定到特定的类;(2)保存模型时,使用确切目录结构。因此,当在其他项目中使用或重构后,您的代码可能会以各种方式中断;

    # 保存
    torch.save(model, mymodel.pth)
    
    # 加载
    model = torch.load(mymodel.pth)  # 不需要重构模型结构,直接load即可
    model.eval()
    

三、可能出现的问题

  • pytorch模型加载问题

    raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
    RuntimeError: yolov5s.pt is a zip archive (did you mean to use torch.jit.load()?)
    
    参考资料 [RuntimeError: xxx.pth is a zip archive (did you mean to use torch.jit.load()?)](https://blog.csdn.net/studyeboy/article/details/116451980)
    错误原因:
    pytorch版本不匹配的问题。比如,用torch 1.8.1训练保存的模型,用torch 1.1.0进行模型加载。PyTorch的1.6版本将torch.save切换为使用新的基于zipfile的文件格式,但 torch.load仍然保留以旧格式加载文件的功能。
    
    方法一:
    保存模型时,传递kwarg _use_new_zipfile_serialization = False参数,使用旧格式。
    torch.save(net.state_dict(), 'model.pth', _use_new_zipfile_serialization=False)
    
    方法二:
    保存模型时,将压缩格式转换为非压缩格式
    import torch
    from model import U2NET
    
    net = U2NET(3, 1)
    state_dict = torch.load('model.pth')
    net.load_state_dict(state_dict)
    torch.save(net.state_dict(), 'model.pth',_use_new_zipfile_serialization=False)
    
    方法三:
    下载新版本的pytorch
    

你可能感兴趣的:(深度学习,pytorch)