参考翻译 SAVING AND LOADING MODELS
三个核心函数:
torch.save
将序列化对象保存到磁盘
该函数使用 pickle
进行序列化
包括 models,tensors,dictionaries 等所有类型对象都可以使用该函数保存
torch.load
反序列化加载到内存
torch.nn.Module.load_state_dict
使用反序列化的 state_dict
加载模型参数字典
torch.nn.Module
的可学习化参数(learnable parameters, eg. weights 和 bias) 都包含在模型的参数中(保存在 model.parameters()
)
state_dict
只是一个Python dictionary
对象,它将每个层映射到它的参数张量
需要注意的是,只有拥有可学习(learnable parameters)参数的神经网络层(eg. convolutional layers, linear layers)和注册的缓存(batchnorm的running_mean)才在 state_dict
中有条目。
优化器对象(torch.optim
)也有 state_dict
,其中包含优化器状态的相关信息以及使用的超参数。
因为 state_dict
是 Python dictionaries
,所以它们可以很容易的保存,替换,更新,添加。
print("Model's state_dict")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Optimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Save:
torch.save(model.state_dict(), PATH)
Load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
在预测之前需要调用
model.eval()
来设置dropout
和batch normalization layers
为evaluation
模式。
一般使用.pt
或pth
后缀保存模型文件
Save:
torch.save(model, PATH)
Load:
model = torch.load(PATH)
model.eval()
使用该方式保存模型,缺点是序列化的数据依赖于特定的类和额外的数据结构。而
pickle
不保存模型的类本身,而是保存包含这个类的文件的位置。
因此,如果进行代码重构的话,会出现问题。
在训练中,每隔 M 个 epoch 保存一次模型,避免训练中断,以恢复模型
Save:
torch.save({
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict:" optimizer.state_dict(),
"loss": loss,
...
}, PATH)
Load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
model.eval()
# or train after break
modle.train()
一般使用
.tar
文件后缀保存 checkpoint 文件
Save:
torch.save(modelA.state_dict(), PATH)
Load:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
迁移学习
设置strict=False
来忽略不匹配的键值
除此以外,也可以只加载某些匹配的神经网络层的参数
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH), map_location=device)
在 CPU 上调用 GPU 上训练的模型,传递参数
map_location=device
,则重新将张量动态的映射到 CPU 上
torch.save(model.state_dict(), PATH)
Load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
注意,这里需要将模型输入的其他张量
调用input = input.to(device)
调用会返回一个在 GPU 上input
的新的拷贝
而不会重写input
,所以需要重新赋值