pytorch----模型参数的保存与加载

网络模型的保存与加载
1.单纯保存网络模型参数,一条语句即可

torch.save(model.state_dict(), path)
# 其中path=’./model.pth’ , path=’./model.tar’, path=’./model.pkl’
# 保存参数的文件一定要有后缀扩展名。

model.load_state_dict(torch.load(path))

2.还想保存训练采用的优化器、epoch信息等

state = {
		'model': model.state_dict(), 
		'optimizer': optimizer.state_dict(), 
		'epoch': epoch
		}
torch.save(state, path)
# 加载
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint(['epoch'])

你可能感兴趣的:(计算机视觉,深度学习,pytorch)