CUDA out of memory 解决办法

原文链接 https://www.zhihu.com/question/274635237,我在文本检测 test 的时候遇到了 CUDA out of memory ,使用了 torch.no_grad() 和 torch.cuda.empty_cache() 减少显存的使用,效果明显。

一般,占用显存分为三部分:

  1. 网络模型自身参数占用的显存。
  2. 模型计算时(包括forward/backward/optimizer)所产生的中间变量或参数也有占用显存。
  3. 编程框架自身一些额外的开销。

改变网络结构

  1. 减少 batch_size …

  2. 牺牲计算速度减少显存占用量,将计算分为两半,先计算一半模型的结果,保存中间结果再计算后面一半的模型

    # 输入
    input = torch.rand(1, 10)
    # 假设我们有一个非常深的网络
    layers = [nn.Linear(10, 10) for _ in range(1000)]
    model = nn.Sequential(*layers)
    output = model(input)
    
    ### 可进行如下更改
    # 首先设置输入的input=>requires_grad=True
    # 如果不设置可能会导致得到的gradient为0
    
    input = torch.rand(1, 10, requires_grad=True)
    layers = [nn.Linear(10, 10) for _ in range(1000)]
    
    
    # 定义要计算的层函数,可以看到我们定义了两个
    # 一个计算前500个层,另一个计算后500个层
    
    def run_first_half(*args):
        x = args[0]
        for layer in layers[:500]:
            x = layer(x)
        return x
    
    def run_second_half(*args):
        x = args[0]
        for layer in layers[500:-1]:
            x = layer(x)
        return x
    
    # 我们引入新加的checkpoint
    from torch.utils.checkpoint import checkpoint
    
    x = checkpoint(run_first_half, input)
    x = checkpoint(run_second_half, x)
    # 最后一层单独调出来执行
    x = layers[-1](x)
    x.sum.backward()  # 这样就可以了
    
  3. 使用 pooling,减小特征图的 size

  4. 减少全连接层的使用

不修改网络结构

  1. 尽可能使用 inplace 操作,比如 relu 可以使用 inplace=True ,一个简单的使用方法,如下:

    def inplace_relu(m):
        classname = m.__class__.__name__
        if classname.find('Relu') != -1:
            m.inplace=True
    model.apply(inplace_relu)
    
  2. 每次循环结束时删除 Loss,可以节省很少显存,但聊胜于无

  3. 使用 float16 精度混合计算,可以节省将近 50% 的显存,但是要小心一些不安全的操作,如 mean 和 sum,溢出 fp16

  4. 对于不需要 bp 的 forward,如 validation,test 请使用 torch.no_grad(),注意 model.eval() 不等于 torch.no_grad()

    • model.eval()将通知所有图层您处于 eval 模式,这样,batch norm 或 dropout 图层将在 eval 模式下而不是训练模式下工作。
    • torch.no_grad()影响 autograd 并将其停用。 它将减少内存使用量并加快计算速度,无法进行反向传播 ( 在 eval 脚本中不需要 ) 。
  5. torch.cuda.empty_cache() 是 del 的进阶版。

  6. optimizer 的变换使用,理论上 sgd < momentum < adam,可以从计算公式中看出有额外的中间变量

  7. Depthwise Convolution

  8. 不要一次性把数据加载进来,而是部分地读取,这样就基本不会出现内存不够的问题

你可能感兴趣的:(Linux+工具)