关于pytorch使用多个dataloader并使用zip和cycle来进行循环时出现的显存泄漏的问题

关于pytorch使用多个dataloader并使用zip和cycle来进行循环时出现的显存泄漏的问题

如果我们想要在 Pytorch 中同时迭代两个 dataloader 来处理数据,会有两种情况:一是我们按照较短的 dataloader 来迭代,长的 dataloader 超过的部分就丢弃掉;二是比较常见的,我们想要按照较长的 dataloader 来迭代,短的 dataloader 在循环完一遍再循环一遍,直到长的 dataloader 循环完一遍。

两个dataloader的写法及问题的出现

第一种情况很好写,直接用 zip 包一下两个 dataloader 即可:

# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    for i, data in enumerate(zip(dataloaders1, dataloaders2)):
        print(data)
        # 开始写你的训练脚本

第二种情况笔者一开始时参考的一篇博客的写法,用 cycle 将较短的 dataloader 包一下:

from itertools import cycle
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    for i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):
        print(data)
        # 开始写你的训练脚本

是可以运行,但是这样出现了明显显存泄漏的问题,在笔者自己的实验中,显存占用量会随着训练的进行,每轮增加 20M 左右,最终导致显存溢出,程序失败。

解决方法

笔者找了半天,终于在 StackOverflow 的一篇贴子中找到了解决方法,该贴的一个答案指出:cyclezip 的方法确实可能会造成显存泄漏(memory leakage)的问题,尤其是在使用图像数据集时,可以通过以下写法来迭代两个 dataloader 并避免这个问题:

# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    dataloader_iterator1 = iter(dataloaders1)
    
    for i, data2 in enumerate(dataloaders2):

        try:
            data1 = next(dataloader_iterator1)
        except StopIteration:
            dataloader_iterator1 = iter(dataloaders1)
            data1 = next(dataloader_iterator1)
        print(data1, data2)

        # 开始你的训练脚本

笔者亲测这种方式是可以正常运行且不会有显存泄漏问题的。

你可能感兴趣的:(PyTorch,issues,pytorch,python)