Python Bug 关于PyTorch-Dataset 内存持续增长bug

太长不看

List[Dict] 对Dict添加数据,如果数据仅仅时临时数据(希望函数结束自动释放), 不会的,会一直保存。
很简单, 但是如果隐藏在复杂的程序逻辑中,可能不是很好发现。仅以此篇警戒。

bug 重现

import os
import psutil
from torch.utils.data import Dataset, DataLoader


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.datas = []
        # 为了方便重现 - 对数据的模拟
        for i in range(1000):
            self.datas.append({"idx": i})
        self.length = len(self.datas)

    def __getitem__(self, index):
        _sample = self.datas[index]
        # 处理数据
        _sample["new"] = "new" * 100000  # bug reason
        return _sample

    def __len__(self):
        return self.length


dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=64)   # noqa: E221

process = psutil.Process(os.getpid())
for i, batch in enumerate(loader):
    mm_info = process .memory_full_info()
    print(i, mm_info.uss / 1024 / 1024, "MB")

运行: 将上述内容保存到文件,如main.py, python main.py

$ python main.py
0 216.359375 MB
1 234.67578125 MB
2 252.9921875 MB
3 271.30859375 MB
4 289.62890625 MB
5 307.9453125 MB
6 326.2578125 MB
7 344.57421875 MB
8 362.88671875 MB
9 381.203125 MB
10 399.515625 MB
11 417.83984375 MB
12 436.15234375 MB
13 454.46875 MB
14 472.78125 MB
15 484.2265625 MB

可以看到程序使用的内存在持续增长。
深度学习的数据量还是很大的, 因此很快就会用完内存,报错(Killed)退出。

debug

猜想: 猜想就是有用完的内存,没有释放掉。
可以使用memory_profiler库,观察内存的变化。

在查找问题时, 还发现一些其他的内存泄漏问题, 放在这里供参考。
(本文的问题不时问题泄漏, 是Bug)

  • pytorch内存泄露-dataloader - 知乎 (zhihu.com)
  • pytorch内存泄漏分析案例 | list转tensor - 知乎 (zhihu.com)
  • pytorch中碰到的memory leak问题 - 简书 (jianshu.com)

原因

__getitem__ 使用index取数据,又向数据中写入了新的数据(是临时数据),相当于间接向self.datas写入数据,__getitem__结束时,并不会释放新的临时数据,会一直保存下来,这就是内存一直增长的原因.

deepcopy 修正

知道了原因,就很好修改了. 在__getitem__中新建数据。脱离与self.datas的联系.
直接deepcopy生成新的临时数据,仅修改新的临时数据.
下面展示一些 内联代码片

from copy import deepcopy
	...
    def __getitem__(self, index):
        _sample = self.datas[index]
        ...
    # 修改为
    def __getitem__(self, index):
        _sample = self.datas[index]
        _sample = deepcopy(_sample)

修改后的结果

(pl) qzhao@cu12:~/codes/nes$ python mm.py
0 215.640625 MB
1 233.95703125 MB
2 215.76953125 MB
3 233.9609375 MB
4 215.76953125 MB
5 233.9609375 MB
6 215.7734375 MB
7 233.96484375 MB
8 215.7734375 MB
9 233.96484375 MB
10 215.78125 MB
11 233.97265625 MB
12 215.78125 MB
13 233.97265625 MB
14 215.78125 MB
15 227.10546875 MB

内存使用在跳动,不再时持续增长。

你可能感兴趣的:(python,pytorch,bug)