In PyTorch, using a custom dataset requires defining a Dataset subclass and overriding the parent class's __len__ and __getitem__ methods.
For example, __getitem__ usually returns an ordinary data pair x, y, but it can also return several x's and y's: in few-shot learning, a query/support pair is exactly that, two x's and two y's.
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, x_data, y_data):
        # store the underlying data
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        # load / preprocess one sample here
        x = self.x_data[idx]
        y = self.y_data[idx]
        return x, y
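A quick usage sketch (the tensor shapes here are made up purely for illustration):

import torch

x = torch.randn(100, 3)           # 100 samples with 3 features each
y = torch.randint(0, 2, (100,))   # 100 binary labels
dataset = MyDataset(x, y)
print(len(dataset))   # 100
print(dataset[0])     # (a tensor of shape (3,), a 0-dim label tensor)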
However, __getitem__ can also return a dict, for example:
def __getitem__(self, idx):
    '''
    (sample loading and preprocessing omitted)
    '''
    batch = {'query_img': query_img,
             'query_mask': query_mask,
             'query_name': query_name,
             'query_ignore_idx': query_ignore_idx,
             'support_imgs': support_imgs,
             'support_masks': support_masks,
             'support_names': support_names,
             'support_ignore_idxs': support_ignore_idxs,
             'class_id': torch.tensor(class_sample)}
    return batch
Let's look at why returning a dict works.
Usually, after defining a Dataset and instantiating it, we create a DataLoader and pass the dataset in. The DataLoader's job is to combine the samples returned by multiple __getitem__ calls into one batch. When instantiating a DataLoader there is a parameter, collate_fn, which defines how the samples are merged into a batch.
# parameter description (from the DataLoader docstring)
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
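With the default collate_fn, a dict-returning dataset batches transparently. A minimal sketch, assuming dataset is an instance of the dict-returning Dataset above (batch_size and shapes are made up for illustration):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, shuffle=True)  # collate_fn defaults to default_collate
for batch in loader:
    # batch is still a dict: each tensor field gains a leading batch dimension,
    # e.g. batch['query_img'].shape == (4, C, H, W),
    # while a string field such as batch['query_name'] becomes a list of 4 strings
    break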
Now let's look at how the default collate_fn, default_collate, is defined:
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
The branch that handles dicts is this one:

elif isinstance(elem, collections.abc.Mapping):
    return {key: default_collate([d[key] for d in batch]) for key in elem}

Here elem is the first sample in the batch; the dict comprehension gathers the values that share the same key across all samples into one list, then calls default_collate recursively on each of those lists.
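A small demonstration of that branch on toy data (assuming a recent PyTorch, >= 1.11, where default_collate is exported from torch.utils.data; in older versions it lives in torch.utils.data._utils.collate):

import torch
from torch.utils.data import default_collate

samples = [
    {'img': torch.zeros(3, 4), 'label': 0},
    {'img': torch.ones(3, 4), 'label': 1},
]
batch = default_collate(samples)
print(batch['img'].shape)  # torch.Size([2, 3, 4]) -- tensors stacked along a new dim 0
print(batch['label'])      # tensor([0, 1])        -- ints collated via the int branch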
Finally, let's look at where collate_fn is called:
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
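Since fetch simply hands the list of samples to whatever collate_fn was supplied, you can replace the default when it cannot stack your data. A hedged sketch (field names borrowed from the dict example above; it assumes each sample may carry a different number of support images, so default stacking would fail):

import torch
from torch.utils.data import DataLoader

def my_collate(samples):
    # samples is the list built by fetch(): one dict per __getitem__ call
    return {
        # fixed-size fields: stack along a new batch dimension, like default_collate
        'query_img': torch.stack([s['query_img'] for s in samples], 0),
        'class_id': torch.stack([s['class_id'] for s in samples], 0),
        # ragged field: keep as a plain Python list instead of stacking
        'support_imgs': [s['support_imgs'] for s in samples],
        # strings: default_collate would also just return these as a list
        'query_name': [s['query_name'] for s in samples],
    }

# loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate)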