PyTorch预训练模型保存与加载

模型的保存

'''
torch.save()函数保存序列化的对象。
'''

# 保存整个模型
torch.save(model, './path')

# 仅保存参数
torch.save(model.state_dict(), './path')

模型的加载

# 整个模型的加载
model = torch.load('./path')

# 获取权重值
checkpoint = model.state_dict()

当使用pytorch自己训练了一个模型并保存,下次想要直接加载使用时,必须清楚这个模型结构的所有内容来自PyTorch自带函数,还是有自定义的部分。若有自定义的部分则必须在使用它之前import或者写好自定义的部分,意即给出自定义的层、model类等。比如:

from TheModelByYourself import Layer1, Layer2, Function1, Function2

另外,model.state_dict()里面仅有定义为可训练的参数,可以自己打印出来看一下。
想要保存额外的参数,可以在保存时自定义保存内容,比如:

torch.save(	
			{'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss}, 
            './path'
          )

你可能感兴趣的:(机器学习)