我们自定义的Dataset类必须要实现:
class dataset(Dataset):
def __init__(self, corpus_path, sentence_max_length):
pass
def __getitem__(self, index):
pass
def __len__(self):
pass
可以看到自定义的dataset必须要有len()方法,和下标索引方法
所以我们并不是构建一个dataloader,必须要自定义一个Dataset类。
只要我们传入的数据集支持len,下标访问,且每次访问返回一个数据和标签就可以了,例如列表和数组。
Example:
dataset = [np.numpy(data),int(label)]
len(dataset)
data,label = dataset[index]
上面这个例子没有构建dataset类但是也满足Dataset类的所有特征,所以也可以直接加载到DataLoader中,但是可能在适配自己的模型时需要自己实现一下DataLoader中的collate_fn()函数,来对数据进行正确拼接。
至于是否需要进行自定义collate_fn ,主要看我们输入的数据是否为tensor格式, 如果内部元素是tensor格式,那么就不需要自己重新实现collate_fn, 如果内部元素不是tensor格式,就需要自己重新实现该函数。
关于如何构建collate_fn
: 从0构建一个collate_fn函数至于为什么,可以看下面的default_collate()源码。如果内部元素是tensor格式的话,则可以进入第一个if分枝语句,通过torch.stack()进行构建batch。 如果内部元素不是tensor格式,例如为元组或者,列表形式。那么就要进入最后一个else分枝语句块,通过执行transposed = zip(*batch),这产生的结果可能不是我们所期望的。
DataLoader(dataset: Dataset[T_co], batch_size: Optional[int] = 1,shuffle: bool = False,
sampler: Optional[Sampler[int]] = None,batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False)
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))