torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
2.1 作用
对数据进行预处理,输出可以用来进行训练的数据张量。
2.2 参数
dataset:传入torch.utils.data.Dataset类的一个实例。
batch_size:mini batch的大小,应为int型。
shuffle:True的话是指数据是否会被随机打乱,默认False。
sampler:自定义的采样器(shuffle=True时会构建默认的采样器,如果想使用自定义的方法需要构造一个torch.utils.data.Sampler的实例来进行采样,并设置shuffle=False,将实例作为参数传入),返回一个数据数据的下标索引。
batch_sampler:和sampler类似,不过batch_sampler返回的是一个mini batch的数据索引,而sampler返回的是下标索引。
num_workers:dataloader使用的进程数目,应为int型。
collect_fn:传入一个自定义的函数,定义如何把一批Dataset的实例转换为包含迷你批次的数据张量,例如这里是YoloV3里的collect_fn:
def yolo_dataset_collate(batch):
images = []
bboxes = []
for img, box in batch:
images.append(img)
bboxes.append(box)
images = np.array(images)
return images, bboxes
pin_memory:True的话会把数据转移到和GPU内存相关联的CPU内存中,从而能够加快GPU载入数据的速度。
drop_last:设置为True的话,当batch_size不能整除dataset里的数据总数时,会将最后一个batch抛弃,也就是说每一个batch都严格等于batch_size。
timeout:值如果大于零,就会决定在多进程情况下对载入数据的等待时间。
worker_init_fn:决定了每个子进程开始时运行的函数,这个函数运行在随机种子设置以后、载入数据之前。
multiprocessing_context:官方文档暂时未给出。
generator:如果不是none,这个随机数发生器将用来生成随机索引和多进程。(官方文档翻译过来的)
prefetch_factor:每个进程开始之前预加载的sample数。
persistent_workers:如果设置为True,dataloader不会在数据集被使用一次后关闭工作进程。
2.3 使用方法
在创建好了一个DataLoader的实例之后,需要利用for循环来读取批量数据,在循环中进行每个batch的训练:
gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers = num_workers, pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate)
# 第一个参数train_dataset 为Dataset类的一个实例
for iteration, batch in enumerate(gen):
# for循环里为每个batch的训练内容
Dataloader每次循环时,先使用Dataset里的__getitem__方法获取batchsize个数据(也就是上面代码for循环里的batch),再使用collect_fn函数对batch做一些自定义的操作。
参考《深入浅出PyTorch》张校捷