神经网络训练过程中保存参数与加载参数

神经网络的训练往往需要一定的时间,如果训练过程中需要临时中断,其训练参数的保存与重新加载显得至关重要。

模型参数的保存:

# 模型参数保存,model是网络模型,optimizer是优化器
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(state, '文件名.pth')

模型参数的加载:

checkpoint=torch.load('文件名.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("加载成功!")

或者只保留模型参数,不保留优化器参数:

# 参数保存
torch.save(model.state_dict(), '文件名.pth')

# 参数加载
model.load_state_dict(torch.load('文件名.pth'))
print("加载成功!")

你可能感兴趣的:(深度学习,深度学习,python)