Pytorch模型后缀

Pytorch是深度学习领域中非常流行的框架之一,支持的模型保存格式包括.pt和.pth .bin .onnx。

模型的保存与加载到底在做什么?

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

模型代码:

(1)包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
(2)包含了我们如何定义的训练过程,包括epoch batch_size等参数;
(3)包含了我们如何加载数据和使用;
(4)包含了我们如何测试评估模型。

模型参数:
提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。

(1)包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
(2)包含optimizer.state_dict(),这是模型的优化器中的参数;
(3)包含我们其他参数信息,如epoch/batch_size/loss等。

数据集:

(1)包含了我们训练模型使用的所有数据;
(2)可以提示对方如何去准备同样格式的数据来训练模型。

使用文档:

(1)根据使用文档的步骤,每个人都可以重现模型;
(2)包含了模型的使用细节和我们相关参数的设置依据等信息。
可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。

现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。

为什么要约定格式?

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。

torch.save: Saves a serialized object to disk. This function uses
Python’s pickle utility for serialization. Models, tensors, and
dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的README.md才知道。而在使用torch.load()方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于torch.load()方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取。

torch.load: Uses pickle’s unpickling facilities to deserialize pickled
object files to memory. This function also facilitates the device to
load the data into (see Saving & Loading Model Across Devices).

顺便提一下,“一切皆文件”的思维才是正确打开计算机世界的思维方式,文件后缀只作为提示作用,在Windows系统中也会用于提示系统默认如何打开或执行文件,除此之外,文件后缀不应该成为我们认识和了解文件阻碍。

你可能感兴趣的:(pytorch,人工智能,python)