PyTorch模型保存和加载以便继续训练

通用的PyTorch模型保存和加载模板

保存模型

state = {
      'model': model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch': epoch }  
torch.save(state, path)

加载模型

checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']+1

你可能感兴趣的:(PyTorch,深度学习,pytorch,机器学习)