torch.save
torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)
torch.load
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
map_location
选择加载到CPU或GPU中
# 保存在 CPU, 加载到 GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# 保存在 GPU, 加载到 CPU
model.load_state_dict(torch.load(PATH, map_location='cpu'))
model.load_state_dict()
model.load_state_dict(state_dict, strict=True)
# 保存
torch.save(model, PATH)
# 加载
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
state_dict
保存加载(推荐)使用state_dict
只保留了权重参数,因此在加载时需要先初始化模型
否则会出现 pytorch AttributeError 报错
保存
torch.save(model.state_dict(), PATH)
加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval() #一定要初始化 不然会报错
一般保存为.pt
或.pth
格式的文件。
1.load_state_dict()函数需要一个 dict 类型的输入,而不是保存模型的 PATH。所以这样 model.load_state_dict(PATH)是错误的,而应该model.load_state_dict(torch.load(PATH))。
2.如果你想保存验证机上表现最好的模型,那么这样best_model_state=model.state_dict()是错误的。因为这属于浅复制,也就是说此时这个 best_model_state 会随着后续的训练过程而不断被更新,最后保存的其实是个 overfit 的模型。所以正确的做法应该是best_model_state=deepcopy(model.state_dict())。
保存和加载 state_dict (没有训练完,还会继续训练)
保存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...你自己的参数
}, PATH)
加载
model = XIAOHU(*args, **kwargs)
optimizer = adam(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
...
model.eval()
# - or -
model.train()