mmdetection之dataloader构建

文章目录

  • 前言
  • 1、总体流程
  • 2、实例化dataloader
    • 2.1. GroupSampler类实现
    • 2.2. BatchSampler类
  • 3、读取一个batch数据流程
  • 总结


前言

 本篇将介绍mmdetection如何构建dataloader类的。dataloader主要控制数据集的迭代读取。与之配套的是首先实现dataset类。关于dataset类的实现请转mmdetection之dataset类构建。


1、总体流程

 在pytorch中,Dataloader实例构建需要以下重要参数(截取dataloader源码)。

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.

 简单介绍下各个参数含义:

dataset:就是继承Dataset类的实例;
batch_size: 批次大小
shuffle: True.在开始新的一轮epoch时,是否会重新打乱数据
sampler:迭代器:里面存储着数据集的下标(可能被打乱/顺序)。是迭代器。
batch_samper: 迭代sampler中下标,然后根据下标去dataset中取出batch_size个数据。
collate_fn:将batch个数据整合进一个list,调整宽和高。

 可能上面几个参数定义有点而蒙。没关系,只需记住dataset,sampler,batch_sampler,dataloader均是迭代器即可。至于迭代器:理解为可以被 for … in dataset:使用即可。
 既然Dataloader主要参数有了,那么现在看下mmdetection中如何build_dataloader的。接下来我打算分两部分进行讲解:
 (1)如何实例化一个dataloader对象。如下图所示:mmdetection中主要实现下边四个参数。GroupSamper继承自torch的sampler类。shuffle大多数都是True。而batch_sampler参数mmdetection使用是pytorch中已实现的BatchSampler类。
mmdetection之dataloader构建_第1张图片

 (2)读取一个batch数据流程。

2、实例化dataloader

2.1. GroupSampler类实现

 dataset的实现请转dataset类构建。这里我贴下GroupSampler源码:

class GroupSampler(Sampler):

    def __init__(self, dataset, samples_per_gpu=1):
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.flag = dataset.flag.astype(np.int64)  #
        self.group_sizes = np.bincount(self.flag)  # np.bincount()函数统计 下标01出现的次数。
        self.num_samples = 0
        for i, size in enumerate(self.group_sizes):
            self.num_samples += int(np.ceil(
                size / self.samples_per_gpu)) * self.samples_per_gpu

    def __iter__(self):
        indices = []
        for i, size in enumerate(self.group_sizes):      # self.group_sizes = [942,4096] ;其中942代表长度比例<1的图像数量;
            if size == 0:
                continue
            indice = np.where(self.flag == i)[0]         # 提取出self.flag中等于当前i的下标。 self.flag顺序存储着训练集中所有图像的aspect-ratio
            assert len(indice) == size
            np.random.shuffle(indice)                    # 这里将下标打乱了
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            indice = np.concatenate(
                [indice, np.random.choice(indice, num_extra)])
            indices.append(indice)
        indices = np.concatenate(indices)                # 合并陈一个list,长度为5011的
        indices = [                                      # 按照batch将list划分:若batch=1,则将列表划分成长度为5011的数组。
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples

 其实主要实现了__iter__方法使其成为一个迭代器。而大致思路就是:假如我有一个5000张图像的数据集。那么数据集下标是0~4999.通过np.random.shuffle打乱5000个下标。假如batch是2,则共得到2500对。将这2500对以数组形式存于indices这个list中。最终通过iter(indices)迭代

2.2. BatchSampler类

 这部分mmdetection使用的是pytorch源码。我贴下源码:

class BatchSampler(Sampler[List[int]]):

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore

 从源码可以看出:BatchSampler以sampler初始化的。同时也实现了__iter__方法,每迭代够一个batch,则借助生成器yield batch,即返回一个batch数据。

3、读取一个batch数据流程

 这里我想用张图说明下:文字不易描述:
mmdetection之dataloader构建_第2张图片

总结

 本文主要介绍mmdetection如何通过实现dataset,sampler来构造一个Dataloader,另外,展示了dataloader内部是如何迭代每个批次数据的。

你可能感兴趣的:(mmdetection,python,mmdetection,pytorch,目标检测)