[pytorch]模型参数保存与加载

最简单的情况

模型保存:

torch.save(model.state_dict(), PATH)

模型加载:

model.load_state_dict(torch.load(PATH))

此时保存的是一个字典,key为model中的weight或bias名,如"linear1.weight"或“linear2.bias”

 

 

 

有时我们使用了优化器

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

我们在保存参数时需要同时保存优化器中的参数:

save_state = {'net':model.state_dict(), 'optimizer':optimizer.state_dict()}
torch.save(save_state, PATH)

在加载时,

model=MyModel()

model.load_state_dict(torch.load("PATH")['net']) 
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
optimizer.load_state_dict(torch.load("lab3_lstmtest_0614.pth")['optimizer'])

这样即保存和加载了模型和优化器参数,继续上一次训练。

你可能感兴趣的:([pytorch]模型参数保存与加载)