pytorch基础知识整理(三)模型保存与加载

1, torch.save(); troch.load()

torch.save()使用python的pickle模块把目标保存到磁盘,可以用来保存模型、张量、字典等,文件后缀名一般用pth或pt或pkl。torch.load()使用python的pickle模块实现从磁盘加载。可以用此来直接保存或加载完整模型:

torch.save(model, 'PATH.pth')
model = torch.load('PATH.pth')

注意:pytorch1.6以后保存的模型使用zip压缩,所以保存的模型无法被1.6以前的版本加载,如果要跨版本使用,需要做以下修改

torch.save(model, 'PATH.pth', _use_new_zipfile_serialization=False)

2, .state_dict(); .load_state_dict()

模型的框架已经在程序代码中了,因此训练好的模型只需要保存模型的参数即可供推理使用。model.state_dict()以字典的形式保存模型的参数,字典的键是参数名,值是参数值的张量。得到状态字典后还需用torch.save()固化到磁盘。
除模型外,优化器optimizer也可以保存和加载状态字典。

torch.save(model.state_dict(), 'PATH.pth')
model.load_state_dict(torch.load('PATH.pth'))

注意在多卡GPU训练时,保存和加载模型需要在model后加上module,即

torch.save(model.module.state_dict(), 'PATH.pth')
model.module.load_state_dict(torch.load('PATH.pth'))

3, 保存checkpoint

如果是训练中途保存用于继续训练,就不仅要保存权重参数,还要保存当前epoch,优化器的状态,当前的损失值等,可以统一打包到一个字典中保存为checkpoint,此时文件后缀名一般用tar。

#保存:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
##加载:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

你可能感兴趣的:(随笔·各种知识点整理,深度学习)