pytorch节省显存小技巧

使用pytorch进行文本多分类问题时遇到了显存out of memory的情况,实验了多种方法,主要比较有效的有两种:
1、尽可能使用inplace操作,比如relu可以使用inplace=True
进一步将BN归一化和激活层Relu打包成inplace,在BP的时候再重新计算。
代码与论文参考
mapillary/inplace_abn
efficient_densenet_pytorch
效果:可以减少一半的显存
原理:
在大多数的深度网络的前向传播中,都有BN-Activation-Conv这样的网络结构,就必须要存储归一化的输入和全卷机层的输入。这是有必要的,因为反向传播需要输入计算梯度。通过重写BN的反向传播步骤,使用ABN代替BN-Activation序列,可以不用存储BN的输入(可通过激活函数的输出即全卷机层的输入反推),节省50%的显存。
2、使用float16精度混合计算,利用NVIDIA 的apex,也能减少50%的显存。但是有一些操作不安全如mean、sum等
官方代码参考
NVIDIA apex
3、pytorch1.0提供了模型拆分成2部分在2张卡上运行的方案
pytorch官网多卡例子
4、使用pytorch1.0的checkpoint特性,可以减少90%的显存
ckeckpoint通过交换计算内存来工作。而不是存储整个计算图的所有中间激活用于向后计算。ckeckpoint不会保存中间的激活参数,而是通过反向传播时重新计算他们。
具体可见我的GitHub一个文本分类项目
我的GitHub

你可能感兴趣的:(NLP)