torch.save
torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)
torch.load
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
map_location
选择加载到CPU或GPU中
# 保存在 CPU, 加载到 GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# 保存在 GPU, 加载到 CPU
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
model.load_state_dict()
model.load_state_dict(state_dict, strict=True)
state_dict
match the keys returned by this module’s state_dict()
function. Default: True
Model class must be defined somewhere
# 保存
torch.save(model, PATH)
# 加载
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
state_dict
保存加载PyTorch 中,torch.nn.Module
里面的可学习的参数 (weights 和 biases) 都放在model.parameters()
里面。而 state_dict 是一个 Python dictionary object,将每一层映射到它的 parameter tensor 上。注意:只有含有可学习参数的层 (convolutional layers, linear layers),或者含有 registered buffers 的层 (batchnorm’s running_mean) 才有 state_dict。优化器的对象 (torch.optim
) 也有 state_dict,存储了优化器的状态和它的超参数。
使用state_dict
只保留了权重参数,因此在加载时需要先初始化模型
否则会出现 pytorch AttributeError 报错
保存
torch.save(model.state_dict(), PATH)
加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
一般保存为.pt
或.pth
格式的文件。
load_state_dict()
函数需要一个 dict 类型的输入,而不是保存模型的 PATH。所以这样 model.load_state_dict(PATH)
是错误的,而应该model.load_state_dict(torch.load(PATH))
。best_model_state=model.state_dict()
是错误的。因为这属于浅复制,也就是说此时这个 best_model_state 会随着后续的训练过程而不断被更新,最后保存的其实是个 overfit 的模型。所以正确的做法应该是best_model_state=deepcopy(model.state_dict())
。保存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
保存
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
加载
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
有时候训练一个新的复杂模型时,需要加载它的一部分预训练的权重。即使只有几个可用的参数,也会有助于 warmstart 训练过程,帮助模型更快达到收敛。
如果手里有的这个 state_dict 缺乏一些 keys,或者多了一些 keys,只要设置strict
参数为 False,就能够把 state_dict 能够匹配的 keys 加载进去,而忽略掉那些 non-matching keys。
保存模型 A 的 state_dict :
torch.save(modelA.state_dict(), PATH)
加载到模型 B:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)