【问题探究】如何解决pytorch训练时的显存占用递增(导致out of memory)

前言:

{

    现在的神经网络模型,动不动就爆内存。两年前我笔记本2G的显存都绰绰有余,现在16G的P100,24G的P40却还不够。更让我郁闷的是,在pytorch训练时,显存占用竟然会不断增加,可能刚开始训练时是正常的,但是放在那里,不知道什么时候它就突然来一句out of memory,然后就尥蹶子不干了,白白浪费了很长的时间。所以这个问题我确实需要搞清楚。

}

 

正文:

{

    首先,我要说一个比较野蛮的办法,就是单独写一个训练脚本,其开始时先载入模型,结束时再保存模型。然后把数据集分割成更小的子数据集(小到模型不会因为显存而尥蹶子不干)。当然,训练脚本的输出参数应当包含数据集(编号)和/或子数据集(编号)。

 

    我去谷歌上搜了一下,最先看到的是[1],上面建议用del删除一些变量,我尝试过用del在每次迭代后删除所有能删除的变量(输入,输出,损失),但是不起效果,模型还是会在同样的迭代次数后报错。

 

    后来我又找到了[2],上面说之后再加上torch.cuda.empty_cache(),这次成功了。

    也就是说,del操作后再加上torch.cuda.empty_cache()才会起效果!代码1是一个例子。

#代码1。
"""添加了最后两行,img和segm是图像和标签输入,很明显通过.cuda()已经是被存在在显存里了;
   outputs是模型的输出,模型在显存里当然其输出也在显存里;loss是通过在显存里的segm和
   outputs算出来的,其也在显存里。这4个对象都是一次性的,使用后应及时把其从显存中清除
   (当然如果你显存够大也可以忽略)。"""

def train(model, data_loader, batch_size, optimizer):

    model.train()
    total_loss = 0
    accumulated_steps = 32 // batch_size
    optimizer.zero_grad()
    for idx, (img, segm) in enumerate(tqdm(data_loader)):
        img = img.cuda()
        segm = segm.cuda()
        outputs = model(img)
        loss = criterion(outputs, segm)
        (loss/accumulated_steps).backward()
        if (idx + 1 ) % accumulated_steps == 0:
            optimizer.step() 
            optimizer.zero_grad()
        total_loss += loss.item()
        
        # delete caches
        del img, segm, outputs, loss
        torch.cuda.empty_cache()

    至于为什么不能直接使用empty_cache(),我在官方文档中找到了[3],上面说empty_cache()不会释放还被占用的内存。所以这里使用了del让对应数据成为“没标签”的垃圾,之后这些垃圾所占的空间就会被empty_cache()回收。

 

    另外,[4]中提到了一些其他方法,例如查看所有tensor对象来确定是否有泄漏,我试了一下,发现打印太多(一个语义分割模型有4000多行的对象打印输出。。。),很难直接找到问题所在。

}

 

结语:

{

    我感觉我又回到了c语言内存泄漏的时代,我觉得这其实比较考验程序员的心思缜密程度,细心也算是程序员的必备属性(之前听过一档音频节目,我记得主持人说程序员其实和很久以前的电报员有些类似,最终都会变成女性比较适合的职业。现在就细心看来确实有点类似)。总之各位多多努力吧。

    不知道能不能实现自动回收,不过我感觉自动回收的前提是有比较充足的硬件资源,因为毕竟自动回收的效率没手动回收高。

    参考资料:

    {

        [1] https://discuss.pytorch.org/t/cuda-memory-continuously-increases-when-net-images-called-in-every-iteration/501

        [2] https://forums.fast.ai/t/clearing-gpu-memory-pytorch/14637/3

        [3] https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management

        [4] https://discuss.pytorch.org/t/how-to-debug-causes-of-gpu-memory-leaks/6741/3

    }   

}

你可能感兴趣的:(问题探究,python,神经网络实践,pytorch,神经网络,机器学习)