使用`checkpoint`进行显存优化的学习笔记

1 介绍

Checkpoint的主要原理是:在前向阶段传递到checkpoint中的forward函数会以 torch.no_grad 模式运行,并且仅仅保存输入参数和 forward 函数,在反向阶段重新计算其 forward 输出值。
(引用于《拿什么拯救我的 4G 显卡 | OpenMMLab》)

2 写作思路

  • 只在nn.Module的上层模块使用checkpoint,而不是在大模型的forward函数中写作;

3 示例代码

使用checkpoint的示例代码:

  • InvertedResidual_with_cp
  • HRFPN_with_cp
  • DetectoRS_ResNet_Bottleneck_with_cp (2020)

我们可以学习使用checkpoint进行显存优化,示例代码如下:

def forward(self, x):
    def _inner_forward(x):
        identity = x
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out
        
    # x.requires_grad 这个判断很有必要
    if self.with_cp and x.requires_grad:
        out = cp.checkpoint(_inner_forward, x)
    else:
        out = _inner_forward(x)
    out = self.relu(out)
    return out

你可能感兴趣的:(学习,with_cp)