如果我们想要在 Pytorch 中同时迭代两个 dataloader 来处理数据,会有两种情况:一是我们按照较短的 dataloader 来迭代,长的 dataloader 超过的部分就丢弃掉;二是比较常见的,我们想要按照较长的 dataloader 来迭代,短的 dataloader 在循环完一遍再循环一遍,直到长的 dataloader 循环完一遍。
第一种情况很好写,直接用 zip
包一下两个 dataloader 即可:
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
for i, data in enumerate(zip(dataloaders1, dataloaders2)):
print(data)
# 开始写你的训练脚本
第二种情况笔者一开始时参考的一篇博客的写法,用 cycle
将较短的 dataloader 包一下:
from itertools import cycle
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
for i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):
print(data)
# 开始写你的训练脚本
是可以运行,但是这样出现了明显显存泄漏的问题,在笔者自己的实验中,显存占用量会随着训练的进行,每轮增加 20M 左右,最终导致显存溢出,程序失败。
笔者找了半天,终于在 StackOverflow 的一篇贴子中找到了解决方法,该贴的一个答案指出:cycle
加 zip
的方法确实可能会造成显存泄漏(memory leakage)的问题,尤其是在使用图像数据集时,可以通过以下写法来迭代两个 dataloader 并避免这个问题:
# ...
dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
dataloader_iterator1 = iter(dataloaders1)
for i, data2 in enumerate(dataloaders2):
try:
data1 = next(dataloader_iterator1)
except StopIteration:
dataloader_iterator1 = iter(dataloaders1)
data1 = next(dataloader_iterator1)
print(data1, data2)
# 开始你的训练脚本
笔者亲测这种方式是可以正常运行且不会有显存泄漏问题的。