Pytorch学习笔记(6) Pytorch模型以及相关参数的保存与加载

当模型训练完成或者训练到一半,我们需要将模型保存。这里介绍如何保存模型的参数以及其他信息。当需要使用的使用如何加载。

一、模型的保存与加载

利用PyTorch可以进行模型的保存和加载,主要有以下两种方式。

方法1: 保存于加载整个模型

model = MyModel()
Path = "my_params,pkl"
# 保存模型
torch.save(model,Path)

# 加载模型
model = torch.load(Path)

方法2:保存模型的参数

# 保存模型参数
torch.save(model.state_dict(), Path)

# 加载模型参数,并且载入模型中
params = torch.load(Path)
model = model.load_state_dict(params)

一般推荐第二种方法,和第一种方法相比,第二种方式只保存模型的参数,节省空间,而且灵活性更高。当然,第一种方法的优点在于,你可以通过加载直接使用,不需要先初始化模型。你可以直接把这个文件发送给其他人,其他人可以直接使用。

二、 同时保存其他参数

在训练过程中,我们不仅仅需要保存模型的参数,有时候也可能保存其他参数,比如说优化器、损失函数的参数,训练过程中accuracy、epoch、learnrating等常数信息也可以使用torch.save()保存。

# optim是优化函数,loss是损失函数
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optim_state_dict': optim.state_dict(),
            'loss_state_dict': loss.state_dict(),
            'best_prec1': best_prec1,
        }, 'checkpoint.tar' )

# 加载, params就是一个字典,像使用字典一样获取参数
params = torch.load(Path)
epoch = params["epoch"]
optim.load_state_dict(params["optim_state_dict"])
model.load_state_dict(params["optim_state_dict"])

其实model.state_dict() 的值也是字典,键是每一层神经网络的名字,所以可以根据每一层的名字,取用参数,这样就会十分的灵活。当我们

你可能感兴趣的:(Pytorch学习笔记(6) Pytorch模型以及相关参数的保存与加载)