Error:
Traceback (most recent call last):
  File "main.py", line 103, in train
    for batch_i, (images, labels) in enumerate(train_loader, start=1):
  File "/u。。/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 310, in __iter__
    return DataLoaderIter(self)
  File "/us。。/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 180, in __init__
    self._put_indices()
  File "/us。。/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 219, in _put_indices
    indices = next(self.sample_iter, None)
  File "/u。。/anaconda2/lib/python2.7/site-packages/torch/utils/data/sampler.py", line 119, in __iter__
    for idx in self.sampler:
  File "/u。。/anaconda2/lib/python2.7/site-packages/torch/utils/data/sampler.py", line 50, in __iter__
    return iter(torch.randperm(len(self.data_source)).long())
  File "/us。。/anaconda2/lib/python2.7/site-packages/torch/utils/data/dataset.py", line 16, in __len__
    raise NotImplementedError
NotImplementedError
So this is a NotImplementedError: some method the code relies on was never implemented.
Next, look at how train_loader is built. The code is as follows:
coco_train = CocoDataset(opt, vocab=vocab, imagef='resizedtrain2014',
                         annotf='coco/annotations/captions_train2014.json', transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=coco_train, batch_size=opt.batch_size, shuffle=shuffle,
                                           num_workers=num_workers, collate_fn=collate_fn)
The CocoDataset class is defined as follows:
class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, opt, vocab, imagef, annotf=None, transform=None):
        ...
    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        ...
        return image, img_id
The fix is to add a __len__ method to CocoDataset:
def __len__(self): return len(self.。。。)
After rerunning, the error is gone.
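For reference, here is a minimal, self-contained sketch of a map-style dataset that works with DataLoader(shuffle=True). It assumes PyTorch is installed. ToyCaptionDataset, its placeholder data, and the num_items parameter are hypothetical stand-ins for CocoDataset, not the author's code; the point is only that __len__ returns the size of the same collection that __getitem__ indexes into.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyCaptionDataset(Dataset):
    """Hypothetical stand-in for CocoDataset: yields (image tensor, image id) pairs."""

    def __init__(self, num_items=10):
        # Hypothetical fake data; the real CocoDataset reads images and annotations.
        self.img_ids = list(range(num_items))

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        image = torch.zeros(3, 8, 8)  # placeholder image tensor
        return image, img_id

    def __len__(self):
        # The sampler calls len(dataset) to decide which indices to draw,
        # so this must match the range of valid indices for __getitem__.
        return len(self.img_ids)


dataset = ToyCaptionDataset(num_items=10)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

seen = []
for images, ids in loader:
    # images has shape (batch, 3, 8, 8); ids is a 1-D tensor of image ids
    seen.extend(ids.tolist())

print(len(loader))   # 3 batches: 4 + 4 + 2
print(sorted(seen))  # every id appears exactly once per epoch
```

Deriving __len__ from the same list that __getitem__ reads keeps the two consistent: every index the sampler produces is valid, and each item is visited once per epoch.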
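The cause is a requirement on the dataset argument of torch.utils.data.DataLoader: it must implement __len__(self). The traceback shows where it is needed. With shuffling enabled, the DataLoader draws indices from a random sampler (sampler.py, line 50), which calls len(self.data_source) to build a permutation with torch.randperm. Because CocoDataset did not override __len__, that call fell through to the base class torch.utils.data.Dataset, whose __len__ simply raises NotImplementedError (dataset.py, line 16).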
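The mechanism can be illustrated without PyTorch. The sketch below is a simplified imitation, not PyTorch's actual source: BaseDataset mimics the old base class from the traceback (its __len__ raises NotImplementedError), random_sampler mimics what the random sampler does, and NoLenDataset, WithLenDataset and try_sampling are hypothetical names used only for this demonstration.

```python
import random


class BaseDataset:
    """Simplified imitation of the old torch.utils.data.Dataset base class."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


def random_sampler(data_source):
    """Simplified imitation of a random sampler: a shuffled permutation of all indices."""
    n = len(data_source)  # fails here if the subclass never overrides __len__
    indices = list(range(n))
    random.shuffle(indices)
    return iter(indices)


class NoLenDataset(BaseDataset):
    # Like the original CocoDataset: __getitem__ only, no __len__
    def __getitem__(self, index):
        return index * 2


class WithLenDataset(NoLenDataset):
    # Like the fixed CocoDataset: __len__ added
    def __len__(self):
        return 5


def try_sampling(dataset):
    """Return the sorted sampled indices, or the exception name if sampling fails."""
    try:
        return sorted(random_sampler(dataset))
    except NotImplementedError:
        return "NotImplementedError"


print(try_sampling(NoLenDataset()))    # NotImplementedError
print(try_sampling(WithLenDataset()))  # [0, 1, 2, 3, 4]
```

In more recent PyTorch releases the base Dataset class no longer defines __len__ at all, so the same omission typically surfaces as a TypeError ("object of type ... has no len()") rather than NotImplementedError; the fix is the same.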