Pytorch中的DataLoader内存泄漏导致RAM爆炸

最近跑一个新模型,但是刚开始跑一个epoch,就爆出了CUDA的OOM,看了一眼是RAM占用满了。
一开始很懵逼,后面用memory profiler来检查一下,发现内存占用直线上升。
Pytorch中的DataLoader内存泄漏导致RAM爆炸_第1张图片
到GitHub上提issue,作者让我用larger RAM…(只能说钱多任性)
自己肯定得想办法解决,于是接着用memory profiler分析了每行代码和objgraph查看各变量内存占用情况。最后定位在DataLoader有问题。
因为在代码中用了两个DataLoader,所以用了cycle库。源代码如下:

        if self.mode == 'supervised':
            dataloader = iter(self.supervised_loader)
            tbar = tqdm(range(len(self.supervised_loader)), ncols=135)
        else:
            dataloader = iter(zip(cycle(self.supervised_loader), cycle(self.unsupervised_loader)))
            tbar = tqdm(range(self.iter_per_epoch), ncols=135)

而,如果都添加了cycle的话,会有内存泄露的问题,只需要对数据量较小的Dataloader添加cycle就行。如下图

        if self.mode == 'supervised':
            dataloader = iter(self.supervised_loader)
            tbar = tqdm(range(len(self.supervised_loader)), ncols=135)
        else:
            # dataloader = iter(zip(cycle(self.supervised_loader), cycle(self.unsupervised_loader)))
            dataloader = iter(zip(cycle(self.supervised_loader), self.unsupervised_loader))
            tbar = tqdm(range(self.iter_per_epoch), ncols=135)

问题顺利解决了。

你可能感兴趣的:(pytorch,pytorch)