PyTorch笔记之模型保存和加载

文章目录

    • Saving and Loading Models
      • Before
      • 关于 state_dict
      • 保存/加载 state_dict
      • 保存/加载 整个模型
      • 保存/加载 Checkpoint 以及恢复训练
      • 使用来自不同模型的参数 Warmstarting 模型
      • 跨 GPU 和 CPU 保存和加载模型

Saving and Loading Models

参考翻译 SAVING AND LOADING MODELS

Before

三个核心函数:

  • torch.save
    将序列化对象保存到磁盘
    该函数使用 pickle 进行序列化
    包括 models,tensors,dictionaries 等所有类型对象都可以使用该函数保存

  • torch.load
    反序列化加载到内存

  • torch.nn.Module.load_state_dict
    使用反序列化的 state_dict 加载模型参数字典

关于 state_dict

torch.nn.Module 的可学习化参数(learnable parameters, eg. weights 和 bias) 都包含在模型的参数中(保存在 model.parameters()

state_dict只是一个Python dictionary对象,它将每个层映射到它的参数张量

需要注意的是,只有拥有可学习(learnable parameters)参数的神经网络层(eg. convolutional layers, linear layers)和注册的缓存(batchnorm的running_mean)才在 state_dict 中有条目。

优化器对象(torch.optim)也有 state_dict,其中包含优化器状态的相关信息以及使用的超参数。

因为 state_dictPython dictionaries ,所以它们可以很容易的保存,替换,更新,添加。

print("Model's state_dict")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Optimizer's state_dict")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

保存/加载 state_dict

Save:

torch.save(model.state_dict(), PATH)

Load:

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))
model.eval()

在预测之前需要调用 model.eval() 来设置 dropoutbatch normalization layersevaluation 模式。
一般使用 .ptpth 后缀保存模型文件

保存/加载 整个模型

Save:

torch.save(model, PATH)

Load:

model = torch.load(PATH)
model.eval()

使用该方式保存模型,缺点是序列化的数据依赖于特定的类和额外的数据结构。而 pickle 不保存模型的类本身,而是保存包含这个类的文件的位置。
因此,如果进行代码重构的话,会出现问题。

保存/加载 Checkpoint 以及恢复训练

在训练中,每隔 M 个 epoch 保存一次模型,避免训练中断,以恢复模型

Save:

torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict:" optimizer.state_dict(),
    "loss": loss,
    ...
    }, PATH)

Load:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

model.eval()
# or train after break
modle.train()

一般使用 .tar 文件后缀保存 checkpoint 文件

使用来自不同模型的参数 Warmstarting 模型

Save:

torch.save(modelA.state_dict(), PATH)

Load:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

迁移学习
设置 strict=False 来忽略不匹配的键值
除此以外,也可以只加载某些匹配的神经网络层的参数

跨 GPU 和 CPU 保存和加载模型

  • 在 GPU 上保存,在 CPU 上加载
    Save:
torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cpu")
model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH), map_location=device)

在 CPU 上调用 GPU 上训练的模型,传递参数 map_location=device,则重新将张量动态的映射到 CPU 上

  • 在 GPU 上保存,在 GPU 上调用
    Save:
torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")

model = TheModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))
model.to(device)

注意,这里需要将模型输入的其他张量
调用 input = input.to(device)
调用会返回一个在 GPU 上 input新的拷贝
而不会重写 input,所以需要重新赋值

你可能感兴趣的:(PyTorch)