【Pytorch】保存和加载模型

官方:https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

【参考:PyTorch保存和加载模型_正则化的博客-CSDN博客】

保存和加载权重参数

PyTorch 模型将学到的参数存储在内部状态字典中,称为 state _ dict。可以通过 torch.save 方法持久化这些内容:

#----把模型中的参数保存成字典的形式, 不保存网络模型的结构, 官方推荐----
torch.save(model.state_dict(), 'params_name.pth') #保存的文件名后缀一般是.pt或.pth

要加载模型权重,您需要首先创建相同模型的实例,然后使用 load _ state _ dict ()方法加载参数。

#----加载----
model=Model() #定义模型结构
model.load_state_dict(torch.load('params_name.pth'))  #加载模型参数

保存和加载带权重参数的模型

#----保存----
torch.save(model, 'model_name.pth')
#----加载----
model = torch.load('model_name.pth')

应用

保留验证集上最好的模型

【参考:Pytorch保留验证集上最好的模型_我是天才很好的博客-CSDN博客】

min_loss = 100000 # 随便设置一个比较大的数
for epoch in range(epochs):
    train()
    val_loss = val() # val()是验证函数 返回loss
    if val_loss < min_loss:
        min_loss = val_loss
        print("save model")
        torch.save(net.state_dict(),'model.pth')

【Pytorch】保存和加载模型_第1张图片
【Pytorch】保存和加载模型_第2张图片

在加载的模型基础上继续训练

【参考:Pytorch模型保存与加载,并在加载的模型基础上继续训练 - 简书】


你可能感兴趣的:(#,+,Pytorch,深度学习)