解决Pytorch加载模型时占用大量内存的问题

一、背景

公司项目中有一个文本分类任务用使用了深度学习模型,模型使用torch实现的,所有的训练和测试流程都是在GPU机器上进行的,所以保存的模型也是以GPU Tensor形式存储的。整个流程在GPU机器上都没有问题(机器内存64GB)。直到有一天把代码搬到没有GPU的机器上(内存16GB),进行预测任务,这时torch加载模型时爆内存了。。。

二、问题的解决

解决这个问题所谓时一波三折。

1. 优化数据集词表

最先想到的问题是数据集太大了,因为在GPU机器上发现程序最高内存占用到36GB,而且模型预测时需要加载词表(torchtext实现),貌似torchtext不能加载用户的自定义词表(刚接触还没有详细研究)?之前为了方便就直接加载训练集生成词表。所以这肯定是占用内存过高的一个原因。通过一顿操作,首先存储torchtext.vocab.itostorchtext.vocab.stoi的词表,之后预测时读取这两个映射,最后发现,内存的占用确实减少了一些,但是在没有GPU的机器上还是爆内存啊

2. 加载模型时添加map_location参数

想到这个参数的原因可能是在调试的时候发现每次读取模型时就会爆内存,也可能是我Google相关问题时无意间受到了启发。于是我决定看看torch.load()这个函数有没有参数可以降低读取模型时内存的占用,发现没有什么特别的参数,除了map_location(好像也可以通过io.BytesIO来读取模型,时间原因就没有研究),因为正常情况下,通过GPU训练的模型,如果想要在CPU上使用,确实需要在加载模型的时候加上参数map_location,代码示例如下:

model = SomeModelClass(**args)
model.load_state_dict(torch.load(model_path, map_location=troch.device('cpu')))

加上map_location参数之后,神奇的事情发生了,模型加载时内存的占用断崖式减小,从16GB减小到了2GB(不严谨统计)。
截止到我写这篇文章也没有理解map_location=troch.device('cpu')会降低内存的原因,时间关系我也没有搜索相关资料,反正torch加载模型时爆内存的问题解决了。欢迎知道原因的dalao科普或指正。

你可能感兴趣的:(解决Pytorch加载模型时占用大量内存的问题)