pytorch实现模型的保存和加载

第一种方式

# 获得模型的参数和buffer量
path = "state_dict_model.pt"

# 保存
torch.save(model.state_dict(),path)

# 加载
model = Network(input_num)
model.load_state_dict(torch.load(path))
# 将内部的training参数 设置为FALSE 这样在直接使用模型进行预测时
# 就不再继续计算梯度值
model.eval()

第二种方式

# 对整个模型进保存和加载
path = "entire_model.pt"

# 保存模型
torch.save(model, path)

# 加载模型
model = torch.load(path)
model.eval()

第三种方式

# 保存checkpoint
path = 'model.pt'
torch.save(
    {
        'epoch':epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict':optimizer.state_dict(),
        'loss': loss_fn
    },path
)

# 加载
model = Network(input_num)
optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=lr)

checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.eval()
# or
model.train()

你可能感兴趣的:(pytorch学习,pytorch,python,深度学习)