Pytorch模型保存/加载方式:①只保存/加载模型参数【推荐】;②保存/加载整个模型(结构+参数);③保存模型Checkpoint;④CPU/GPU保存加载【后缀:pt、pth、pkl】

当提到保存和加载模型时,有三个核心功能需要熟悉:

  1. torch.save:将序列化的对象保存到disk。这个函数使用Python的pickle实用程序进行序列化。使用这个函数可以保存各种对象的模型、张量和字典。
  2. torch.load:使用pickle unpickle工具将pickle的对象文件反序列化为内存。
  3. torch.nn.Module.load_state_dict:使用反序列化状态字典加载model’s参数字典。

一、模型保存与调用方式一:只保存模型参数

1、模型保存

model = TheModelClass(*args, **kwargs)
# ------------- 模型训练: 开始 -------------
....

你可能感兴趣的:(#,Pytorch,pytorch,python,神经网络,模型保存,模型加载)