Dataset负责生产数据,DataLoader负责数据的分批(batch_size)、采样(sampler)、传输
Pytorch版本:1.0.1
继承torch.utils.data.Dataset,实现两个函数即可:
将Dataset作为参数,构造一个torch.utils.data.DataLoader对象即可。
DataLoader其他参数见下文。
dataset(Dataset):
传入的数据集
batch_size(int, optional):
每个batch有多少个样本
shuffle(bool, optional):
在每个epoch开始的时候,对数据进行重新打乱
sampler(Sampler, optional):
自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
batch_sampler(Sampler, optional):
与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
num_workers (int, optional):
这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
collate_fn (callable, optional):
将一个list的sample组成一个mini-batch的函数
pin_memory (bool, optional):
如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
drop_last (bool, optional):
如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
timeout(numeric, optional):
如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
worker_init_fn (callable, optional):
每个worker初始化函数 If not None, this will be called on each
worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
DataLoader构造函数中相关代码:
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset) ##如果shuffer就随机
else:
sampler = SequentialSampler(dataset) ##否则顺序采样
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
batch_sampler是sampler的封装,可自定义批次数据的构造。默认BatchSampler相关源码:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx) ##遍历sampler获取数据,满batch_size就yield
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
例如cirtorch中将数据拆成input_data和target两个数据。
因Dataset中get_item返回input_data和target两个值,如果不用该函数,每个batch的数据应该是[batch_size,2(先input_data再target),],经过该函数将变成([batch_size,],[batch_size,]),第一个数据全是input_data,第二个数据全是target。
pin_memory可在cpu主存(内存)中分配不可交换到swap(缓存)的内存。。默认内存分配中的数据都可交换到swap中,那CUDA驱动会通过DRAM机制将数据从内存传到GPU显存时会复制2次(先复制到一临时不可见pinned固定内存,再往显存中复制),因此pin_memory=True可提高约2倍cpu到gpu传输效率(.cuda()或 .to(device)的时候)。相见CPU和GPU内存交互。
【拓展】Elasticsearch中的Memlock(内存锁定)可申请固定大小且不可交换内存空间。
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
self.index_queues = []
self.workers = []
for i in range(self.num_workers):
index_queue = multiprocessing.Queue() # 1.每个子进程一个队列放要处理的下标
index_queue.cancel_join_thread()
w = multiprocessing.Process(
target=_utils.worker._worker_loop, # 每个子进程循环执行的函数
args=(self.dataset, index_queue,
self.worker_result_queue, self.done_event, #2.self.worker_result_queue 多子进程公用要返回batch数据的队列
self.collate_fn, base_seed + i,
self.worker_init_fn, i))
w.daemon = True
# NB: Process.start() actually take some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self.workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
self.index_queues.append(index_queue)
self.workers.append(w)
每个worker有一个index_queue dataloader.py#L544-L552
每个worker从index_queue取要处理的下标 dataloader.py#L124
dataloader输出一次数据前先往index_queue中放一次下标, _process_next_batch函数:
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices() ## 先放下一批数据下标
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch ## 再返回该批数据
_put_indices依次往不同worker所属的index_queue中放 dataloader.py#L644-L652
完整的dataloader next函数:
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch) ## 5. 之前以及取出来该下标数据,直接返回
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True: ## 1.直到取的数据下标正确才return
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch() ## 2.从worker_result_queue中获取数据
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch ## 3.下标不对先存一下
continue
return self._process_next_batch(batch) ## 4.内部先放下一批数据下标再返回batch数据
每个worker一直在执行的循环_worker_loop,其中worker_result_queue作为_worker_loop函数的data_queue传入(dataloader.py#L544-L552),相见:
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
try:
global _use_shared_memory
_use_shared_memory = True
# Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal happened again already.
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
_set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
data_queue.cancel_join_thread()
if init_fn is not None:
init_fn(worker_id)
watchdog = ManagerWatchdog()
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) ##从index_queue中获取要处理的下标
except queue.Empty:
continue
if r is None:
# Received the final signal
assert done_event.is_set()
return
elif done_event.is_set():
# Done event is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices]) ##1.根据下标取样本数据
except Exception:
# It is important that we don't store exc_info in a variable,
# see NOTE [ Python Traceback Reference Cycle Problem ]
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else: ## 2. 没有抛异常就将样本数据放入结果返回队列
data_queue.put((idx, samples))
del samples
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass