PyTorch model training begins with the data reading and preprocessing module, which has two important entry points. The first is the ImageFolder class, which scans the dataset directory on disk and organizes the sample paths:
train_dataset = datasets.ImageFolder(traindir, transform_train)
The second is DataLoader, which performs the actual data IO and preprocessing. During training, train_loader is passed into the train() function and consumed as an iterator, yielding the corresponding data batch by batch.
train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
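For context, this is roughly how train() consumes train_loader; a minimal sketch in which model, criterion, optimizer and the args namespace are assumed to be defined elsewhere:
for epoch in range(args.epochs):
    for i, (images, target) in enumerate(train_loader):   # each step goes through the iterator machinery below
        images = images.cuda(non_blocking=True)           # non_blocking pairs with pin_memory=True
        target = target.cuda(non_blocking=True)
        output = model(images)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()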
Below we walk through this pipeline in detail, aiming to make this the most thorough reference on PyTorch's data loading code:
- Overall reading order: class ImageFolder(DatasetFolder) -> class DatasetFolder(VisionDataset) -> def make_dataset() and __getitem__ -> self.loader(path) for reading a sample -> default_loader(), the function that finally reads the image
I. datasets.ImageFolder
Following that order, we start with the class ImageFolder(DatasetFolder). Its code is essentially just an __init__ that takes the following arguments:
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
The two arguments that matter most are root, the path to the dataset, and transform, the data augmentation pipeline.
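For reference, transform_train in the snippet at the top is typically a Compose of standard augmentations; the exact pipeline below is only an assumption, not something fixed by ImageFolder:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
train_dataset = datasets.ImageFolder(traindir, transform_train)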
Next we go into the DatasetFolder class; the analysis is written directly into the code as comments:
class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::
    The goal is to organize the data under the given root so that each sample is recorded as a path string, which makes later reading straightforward:
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions=None, transform=None,
target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)  # discover the class folders under root
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)  # organize every file under root as a (path string, class index) pair
if len(samples) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
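The _find_classes call above is short; a sketch of what it does in this torchvision version (os is imported at module level; older releases use os.listdir instead of os.scandir):
    def _find_classes(self, dir):
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]   # one class per sub-directory
        classes.sort()                                              # sorted, so class ids are deterministic
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx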
make_dataset itself:
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
Args:
directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]  # class name -> integer id mapping
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
                    item = path, class_index  # pack the file path together with its class id; all items go into instances and are returned
instances.append(item)
return instances
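For the example directory layout shown in the DatasetFolder docstring, the result looks roughly like this (paths shortened for illustration):
# class_to_idx == {'class_x': 0, 'class_y': 1}
# make_dataset(root, class_to_idx, IMG_EXTENSIONS, None) returns:
# [('root/class_x/xxx.ext', 0),
#  ('root/class_x/xxy.ext', 0),
#  ('root/class_x/xxz.ext', 0),
#  ('root/class_y/123.ext', 1),
#  ('root/class_y/asd932_.ext', 1),
#  ('root/class_y/nsdf3.ext', 1)]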
__getitem__ is what accesses an individual sample: given an index, it looks up the path and calls sample = self.loader(path) to perform the actual IO.
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
        sample = self.loader(path)  # read the raw sample from disk
if self.transform is not None:
            sample = self.transform(sample)  # apply the data augmentation pipeline
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
The self.loader(path) here is the loader argument, declared as loader: Callable[[str], Any] = default_loader, i.e.:
def default_loader(path: str) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
which ultimately calls
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
# st = time.time()
img = Image.open(f)
img = img.convert('RGB')
# print(f'img.convert {path}: {time.time() - st}')
return img
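For completeness, the accimage branch of default_loader falls back to PIL when decoding fails; it looks like this in the torchvision source of this era:
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)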
Coming back to where we started: the ImageFolder instance is used mainly through samples, the packed list of sample paths with their class ids, and targets, the corresponding class index for each image.
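A quick sanity check of those attributes (the printed values are illustrative):
print(train_dataset.classes)        # e.g. ['class_x', 'class_y']
print(train_dataset.class_to_idx)   # e.g. {'class_x': 0, 'class_y': 1}
print(train_dataset.samples[0])     # e.g. ('root/class_x/xxx.ext', 0)
print(len(train_dataset))           # total number of images found
sample, target = train_dataset[0]   # triggers __getitem__: loader(path) + transform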
II. torch.utils.data.DataLoader
1 num_workers = 0
- Reading order: class DataLoader(object) -> __iter__ -> _SingleProcessDataLoaderIter -> the _MapDatasetFetcher data reader
Next, the prepared ImageFolder dataset is passed into torch.utils.data.DataLoader:
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
This class turns the dataset into an iterable, so that the for loop inside train() can read one preprocessed batch at a time. The code here is a bit more involved. The arguments are the packed dataset, the batch size, whether to shuffle, the number of workers, whether to use pinned memory, and so on. __init__ is mostly validation and sanity checks on these arguments. The function that matters most, in my view, is __iter__:
def __iter__(self) -> '_BaseDataLoaderIter':
# When using a single worker the returned iterator should be
# created everytime to avoid reseting its state
# However, in the case of a multiple workers iterator
# the iterator is only created once in the lifetime of the
# DataLoader object so that workers can be reused
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
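In practice __iter__ is triggered implicitly by the for loop (or enumerate) over the DataLoader; it is equivalent to driving the iterator by hand:
it = iter(train_loader)    # DataLoader.__iter__ -> _get_iterator()
images, target = next(it)  # _BaseDataLoaderIter.__next__ -> _next_data()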
There are two cases. The first is the single-process DataLoader (num_workers == 0), which goes into:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
        index = self._next_index()  # may raise StopIteration; produces the next batch of indices to fetch
        data = self._dataset_fetcher.fetch(index)  # hand the indices to the fetcher, which reads the data
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
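The self._next_index() call simply pulls the next list of sample indices from the batch sampler that DataLoader builds out of batch_size, shuffle and sampler; a standalone illustration (the output varies because of shuffling):
from torch.utils.data.sampler import RandomSampler, BatchSampler

batch_sampler = BatchSampler(RandomSampler(range(10)), batch_size=4, drop_last=False)
print(list(batch_sampler))   # e.g. [[3, 7, 0, 9], [5, 1, 8, 2], [6, 4]]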
Its __init__ first creates a data fetcher through _DatasetKind.create_fetcher, which for a map-style dataset ends up in _MapDatasetFetcher:
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
    def fetch(self, possibly_batched_index):  # fetch the data for the given sample indices
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
The line data = [self.dataset[idx] for idx in possibly_batched_index] inside fetch is what ties back to __getitem__ above: it invokes self.loader plus the transform, i.e. IO plus preprocessing, collects the resulting samples into the data list, and finally collate_fn packs them into a batch.
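That collate_fn defaults to default_collate, which stacks the per-sample (sample, target) pairs into batch tensors. For the common image/label case it behaves roughly like this hand-written stand-in (simple_collate is only an illustration, not the real implementation):
import torch

def simple_collate(batch):
    # batch is a list like [(img_tensor, label), ...] of length batch_size
    images = torch.stack([sample for sample, _ in batch], dim=0)   # (B, C, H, W)
    targets = torch.tensor([target for _, target in batch])        # (B,)
    return images, targets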
2 num_workers > 0
- Reading order: _MultiProcessingDataLoaderIter -> __init__ and _try_put_index (producer side) -> _worker_loop (worker processes) -> __next__ and _next_data (consumer side)
The most important part of data loading is the multi-process path: _MultiProcessingDataLoaderIter(_BaseDataLoaderIter). The class starts with a long block of explanatory comments that we can skip; just follow the thread here. First, its initializer:
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
assert self._num_workers > 0
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context
self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()  # holds the batches the workers have finished
self._worker_pids_set = False
self._shutdown = False
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
self._task_info = {}
self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = []
self._workers = []
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (i.e., if kind != Iterable).
self._workers_status = []
for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()  # holds the indices this worker should fetch
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(  # one process per worker
                target=_utils.worker._worker_loop,  # the function that does the actual reading and preprocessing
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed + i, self._worker_init_fn, i, self._num_workers))
            w.daemon = True  # daemon process
# NB: Process.start() actually take some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self._workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
self._workers_status.append(True)
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
self._data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
# prime the prefetch loop
        for _ in range(2 * self._num_workers):  # issue 2 * num_workers index requests up front
self._try_put_index()
A few parts of this deserve emphasis. Inside the loop, index_queue = multiprocessing_context.Queue() creates a fresh index queue for each worker process, and self._index_queues.append(index_queue) gathers all of these per-worker queues into _index_queues.
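To make the queue layout concrete, here is a deliberately simplified, self-contained toy that mirrors the same structure: per-worker index queues feeding one shared result queue. toy_worker and the arithmetic inside it are hypothetical stand-ins for _worker_loop and the fetcher:
import multiprocessing as mp

def toy_worker(index_queue, result_queue):
    # stands in for _worker_loop: block on the index queue, "process" the batch, report back
    while True:
        task = index_queue.get()
        if task is None:                        # shutdown signal, like the None sent on teardown
            break
        send_idx, indices = task
        result_queue.put((send_idx, [i * 10 for i in indices]))  # stand-in for fetcher.fetch

if __name__ == '__main__':
    num_workers = 2
    result_queue = mp.Queue()                                # plays the role of _worker_result_queue
    index_queues = [mp.Queue() for _ in range(num_workers)]  # one index queue per worker
    workers = [mp.Process(target=toy_worker, args=(q, result_queue), daemon=True)
               for q in index_queues]
    for w in workers:
        w.start()
    index_queues[0].put((0, [0, 1, 2]))   # round-robin assignment, like _try_put_index
    index_queues[1].put((1, [3, 4, 5]))
    print(result_queue.get())             # results may come back out of order
    print(result_queue.get())
    for q in index_queues:
        q.put(None)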
Inside the _try_put_index function:
def _try_put_index(self):
assert self._tasks_outstanding < 2 * self._num_workers
try:
            index = self._next_index()  # produce the indices for one batch
except StopIteration:
return
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
break
else:
# not found (i.e., didn't break)
return
        # Put the indices into _index_queues. Because Python passes object references,
        # updating self._index_queues[worker_queue_idx] updates the very index_queue
        # object that was handed to that worker above.
        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # enqueue this batch's indices
        self._task_info[self._send_idx] = (worker_queue_idx,)  # record which worker this batch was assigned to
        self._tasks_outstanding += 1
        self._send_idx += 1  # one more batch request has been issued
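The worker selection above is nothing more than itertools.cycle over the worker ids (with dead workers skipped via _workers_status):
import itertools

worker_queue_idx_cycle = itertools.cycle(range(3))       # 3 workers
print([next(worker_queue_idx_cycle) for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]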
And each index_queue is passed into its worker process:
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed + i, self._worker_init_fn, i, self._num_workers))
Now for the main event: what each worker process actually does, i.e. target=_utils.worker._worker_loop:
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
num_workers):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
    # index_queue: the indices this worker will be asked to process
    # data_queue: where the finished data is put
try:
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal had already happened
# again.
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
signal_handling._set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
global _worker_info
_worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
seed=seed, dataset=dataset)
from torch.utils.data import _DatasetKind
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)  # build the data fetcher
except Exception:
init_exception = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
# When using Iterable mode, some worker can exit earlier than others due
# to the IterableDataset behaving differently for different workers.
# When such things happen, an `_IterableDatasetStopIteration` object is
# sent over to the main process with the ID of this worker, so that the
# main process won't send more tasks to this worker, and will send
# `None` to this worker to properly exit it.
#
# Note that we cannot set `done_event` from a worker as it is shared
# among all processes. Instead, we set the `iteration_end` flag to
# signify that the iterator is exhausted. When either `done_event` or
# `iteration_end` is set, we skip all processing step and just wait for
# `None`.
iteration_end = False
        watchdog = ManagerWatchdog()  # used to check whether the main (manager) process is still alive
while watchdog.is_alive():
try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)  # get the next batch of indices to process
except queue.Empty:
continue
if r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
            idx, index = r  # unpack the task id (_send_idx) and the batch of sample indices
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
                    data = fetcher.fetch(index)  # read the batch through the fetcher created above
except Exception as e:
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))  # put the task id and the preprocessed batch into the shared result queue
del data, idx, index, r # save memory
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass
if done_event.is_set():
data_queue.cancel_join_thread()
data_queue.close()
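One practical note on the seeding above: the loop only calls random.seed and torch.manual_seed, so other libraries (numpy in particular) are not reseeded per worker in this version, which is exactly what worker_init_fn is for. A hedged sketch of the usual pattern (seed_numpy_in_worker is a made-up name):
import numpy as np
import torch

def seed_numpy_in_worker(worker_id):
    # inside a worker, torch.initial_seed() equals base_seed + worker_id (see _worker_loop above)
    np.random.seed(torch.initial_seed() % 2**32)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler,
    worker_init_fn=seed_numpy_in_worker)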
Once the n worker processes are running, they keep producing preprocessed batches. That is the producer side; how does the consumer side work? Iterating the DataLoader goes through __next__:
def __next__(self):
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
That call lands in data = self._next_data(), i.e. the version implemented by _MultiProcessingDataLoaderIter:
def _next_data(self):
while True:
# If the worker responsible for `self._rcvd_idx` has already ended
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
# we try to advance `self._rcvd_idx` to find the next valid index.
#
# This part needs to run in the loop because both the `self._get_data()`
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead.
while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]  # check whether the batch request at _rcvd_idx has been handled yet
worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data, or the worker is still active
                    # If this batch has been processed, len(info) == 2 (worker id + data); if not, it is 1
                    # (worker id only). If the worker is still alive, break and trust it to deliver the batch.
break
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
                self._shutdown_workers()  # no remaining task can be fulfilled (the workers that owned them are gone)
raise StopIteration
# Now `self._rcvd_idx` is the batch index we want to fetch
# Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:  # the data has already been produced, so pop it out
                data = self._task_info.pop(self._rcvd_idx)[1]  # element [1] is the data itself
                return self._process_data(data)  # a finished batch is handed out, and a new index batch is queued up
            assert not self._shutdown and self._tasks_outstanding > 0
            # nothing cached for _rcvd_idx yet, so the main process goes and waits for worker output
            idx, data = self._get_data()  # take the next (batch id, data) pair from the result queue
self._tasks_outstanding -= 1
if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
self._shutdown_worker(data.worker_id)
self._try_put_index()
continue
if idx != self._rcvd_idx:
# store out-of-order samples
                self._task_info[idx] += (data,)  # the batch we pulled belongs to a later _rcvd_idx; cache it and loop again
else:
del self._task_info[idx]
                return self._process_data(data)  # after handing this batch out, a new index request is issued to the workers
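For reference, the _process_data called in both return paths is short in this PyTorch version: it advances _rcvd_idx, tops the index queues back up via _try_put_index, and re-raises any exception a worker captured:
    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data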
In short, the main process issues index requests and distributes them across the per-worker index queues; every batch carries a sequence number, and each worker handles the batches assigned to it. Because workers run at different speeds, batches are not finished strictly in the order they were issued, so later batches can arrive early. When the main process cannot find the batch it currently needs, it calls self._get_data() to pull whatever is available from _worker_result_queue and keeps track of it in _task_info: if the batch it pulled is the one it needs, it returns it; otherwise it caches the batch for the future request that will need it. The while True loop keeps spinning until the required batch is obtained.
On the worker side, if a worker's index queue has no pending index, it simply keeps looping and waiting:
try:
    r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)  # get the next indices to process
except queue.Empty:  # nothing to do yet; keep looping and waiting
continue
New indices are put in as soon as a finished batch is taken out; which worker's queue receives them is decided by the cycle below:
worker_queue_idx = next(self._worker_queue_idx_cycle)
self._index_queues[worker_queue_idx].put((self._send_idx, index))
So every time a batch is consumed, the cycle yields the next worker id, and a new batch of indices goes into that worker's index queue. In other words, consuming a batch produced by worker n does not necessarily refill worker n's queue; the refilling is round-robin.
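To close the loop, here is a tiny self-contained example that exercises the whole pipeline described above with num_workers=2. ToyDataset is a hypothetical stand-in for ImageFolder so that the example runs without any image files:
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):                  # hypothetical stand-in for ImageFolder
    def __len__(self):
        return 8
    def __getitem__(self, index):           # plays the role of self.loader(path) + transform
        return torch.tensor([float(index)]), index % 2

if __name__ == '__main__':
    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=2)
    for samples, targets in loader:         # __iter__ -> _MultiProcessingDataLoaderIter -> __next__
        print(samples.shape, targets)       # e.g. torch.Size([4, 1]) tensor([1, 0, 0, 1])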