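When you hit this kind of problem, especially when the code normally runs fine and only starts failing after you switch to a different dataset, the dataset itself is most likely the culprit. Debug along that line.
With num_workers set to any value greater than 0, the DataLoader fails with the following error: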
Traceback (most recent call last):
File "/home/username/distort/main.py", line 131, in <module>
model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
File "/home/username/distort/main.py", line 65, in train_model
for img, y in train_dataloader:
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
data = self._next_data()
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
return self._process_data(data)
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
data.reraise()
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/_utils.py", line 461, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 140, in default_collate
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable
With num_workers set to 0, a different error appears instead:
Traceback (most recent call last):
File "/home/username/distort/main.py", line 130, in <module>
model, perms, accs = train_model(dinfos, args.mid, args.pretrained, args.num_classes, args.treps, args.testep, args.test_dist, device, args.distort)
File "/home/username/distort/main.py", line 64, in train_model
for img, y in train_dataloader:
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
data = self._next_data()
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in default_collate
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 175, in <listcomp>
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "/home/username/miniconda3/envs/round11/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 141, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 64, 64] at entry 0 and [1, 64, 64] at entry 32
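Both tracebacks end inside default_collate. The num_workers=0 failure can be reproduced without any image files. The sketch below is only an illustration: it uses random tensors as a stand-in for one tiny-imagenet-200 batch (32 RGB samples followed by one grayscale sample, an assumption chosen to match "entry 32" in the message) and calls the same default_collate that the DataLoader uses by default.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Synthetic stand-in for one batch (assumption): 32 RGB images and one
# grayscale image, each paired with an integer label, as __getitem__ returns.
batch = [(torch.rand(3, 64, 64), 0) for _ in range(32)]
batch.append((torch.rand(1, 64, 64), 0))

try:
    # default_collate transposes the (img, label) pairs and torch.stack()s
    # the images, which requires every image in the batch to have one shape.
    default_collate(batch)
except RuntimeError as err:
    print(err)  # "stack expects each tensor to be equal size, but got ..."
```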
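The second error is fairly easy to track down, since the message says the samples in one batch have different shapes ([3, 64, 64] vs. [1, 64, 64]). In the custom Dataset class's __getitem__(), I added code that prints the path of the original image file whenever the loaded tensor's shape[0] is 1.
It turned out the dataset really does contain single-channel (grayscale) images (I was using tiny-imagenet-200). I honestly didn't expect the dataset to be at fault.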
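A minimal sketch of that debugging step. InMemoryImages and find_single_channel are hypothetical names, and holding tensors in a list is an assumption made to keep the example self-contained; a real dataset would read each image from disk inside __getitem__(). Indexing the dataset directly, instead of iterating a DataLoader, skips the collate step, so the scan itself cannot hit the stack error.

```python
import torch
from torch.utils.data import Dataset


class InMemoryImages(Dataset):
    """Hypothetical stand-in for a custom image dataset.

    Each item is (CHW image tensor, label); `paths` records which file each
    tensor came from so that bad samples can be traced back to disk.
    """

    def __init__(self, images, labels, paths):
        self.images, self.labels, self.paths = images, labels, paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if img.shape[0] == 1:
            # Debug hook: report the source file of any single-channel sample.
            print(f"1-channel image at index {idx}: {self.paths[idx]}")
        return img, self.labels[idx]


def find_single_channel(dataset):
    """Index the dataset directly (no DataLoader, so no collate/stack step)
    and return the paths of all single-channel samples."""
    return [dataset.paths[i] for i in range(len(dataset))
            if dataset[i][0].shape[0] == 1]


imgs = [torch.rand(3, 64, 64), torch.rand(1, 64, 64), torch.rand(3, 64, 64)]
ds = InMemoryImages(imgs, [0, 1, 2], ["a.JPEG", "b.JPEG", "c.JPEG"])
print(find_single_channel(ds))  # prints the debug line, then ['b.JPEG']
```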
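The fix: in __getitem__(), for any tensor with the wrong channel count, call the tensor method expand(3, -1, -1) to broadcast its single channel to three. After that, the dataset loads correctly with num_workers set to 0 or to any positive value.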
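A sketch of that fix, assuming images reach __getitem__() as CHW tensors; to_three_channels and FixedImages are hypothetical names. With the conversion in place, a 33-sample batch containing one grayscale image collates without error.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def to_three_channels(img: torch.Tensor) -> torch.Tensor:
    """Broadcast a 1xHxW tensor to 3xHxW; other tensors pass through unchanged."""
    if img.shape[0] == 1:
        # expand() returns a view that repeats the single channel without
        # copying; torch.stack in default_collate later copies it into the batch.
        img = img.expand(3, -1, -1)
    return img


class FixedImages(Dataset):
    """Hypothetical dataset that normalizes the channel count in __getitem__()."""

    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return to_three_channels(self.images[idx]), idx


imgs = [torch.rand(3, 64, 64) for _ in range(32)] + [torch.rand(1, 64, 64)]
loader = DataLoader(FixedImages(imgs), batch_size=33, num_workers=0)
x, y = next(iter(loader))
print(x.shape)  # torch.Size([33, 3, 64, 64])
```

Converting when the image is read also works: PIL's Image.convert("RGB") turns a grayscale image into a real 3-channel image before it becomes a tensor, and torchvision's default pil_loader (used by ImageFolder) does exactly that.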
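Also note: some blog posts claim num_workers has to match the number of GPU cores, which makes no sense. The first traceback already shows that the failure has nothing to do with CUDA, so it cannot be a GPU problem: it is raised in default_collate inside a CPU worker process. With num_workers > 0, default_collate preallocates a shared-memory buffer sized to the total number of elements in the batch and then calls resize_() to reshape it to len(batch) × the shape of the first sample. The single-channel image makes that buffer too small, and shared-memory storage cannot grow, hence "Trying to resize storage that is not resizable". It is the same shape mismatch as the second error. At least with the usual way of loading datasets, num_workers simply sets how many worker subprocesses (CPU processes, not threads) the DataLoader uses to load data; 0 means loading happens in the main process.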
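To see that num_workers is about CPU worker processes, a small sketch can record get_worker_info() and os.getpid() for every sample. WhoLoadsMe and load_all are hypothetical names, and passing multiprocessing_context="fork" is an assumption that only works on POSIX systems; it is there so the snippet can run as a flat script (with the default context, put the calls under `if __name__ == "__main__":` instead).

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class WhoLoadsMe(Dataset):
    """Toy dataset whose items record which process produced them."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        info = get_worker_info()  # None when loading in the main process
        worker_id = -1 if info is None else info.id
        return torch.tensor([idx, worker_id, os.getpid()])


def load_all(num_workers):
    # A multiprocessing context may only be passed when num_workers > 0.
    ctx = "fork" if num_workers > 0 else None
    loader = DataLoader(WhoLoadsMe(), batch_size=2, num_workers=num_workers,
                        multiprocessing_context=ctx)
    return torch.cat(list(loader))  # one row per sample: [idx, worker_id, pid]


print(os.getpid())   # PID of the main process
print(load_all(0))   # worker_id -1 and the main PID in every row
print(load_all(2))   # worker_id 0 or 1, PIDs of two separate worker processes
```

With num_workers=0 every sample is loaded in the main process; with num_workers=2 the samples come from two child processes with their own PIDs, and the GPU is never involved.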