小白科研笔记:深入理解mmDetection框架——数据输入

1. 前言

这篇博客讨论mmDetection框架的数据输入过程。主要讨论SA-SSD的数据流。

2. 数据输入

2.1 DataLoader类

Pytorch训练网络的简单例程可参考这篇简书笔记。数据输入的代码片段如下所示。Pytorch提供DataLoader类用于处理数据的输入,提供做训练的网络输入值(inputs)和真值(labels)。它提供很多数据输入的便捷操作,比如输入数据批处理大小(batch_size),数据集是否需要随机打乱(shuffle),数据集读取使用的线程个数(num_workers)。

from torch.utils.data import DataLoader

trainloader = DataLoader(dataset=trainset, batch_size=4, shuffle=True, num_workers=4)

for i, data in enumerate(trainloader, 0):
    # get the input
    inputs, labels = data

对于DataLoader而言,我最为关心的问题是:它是怎样读取输入数据的?从代码上看,通过调用enumerate函数读取输入数据。再刨根问底地分析,enumerate函数则是调用DataLoader类的类成员变量trainset__getitem__函数。trainset是一个类,它公有继承自类dataset。说起来是不是有点绕呢?我用一个示意图去解释它们。在图中,Get Data Function是使用者需要定义的函数,比如训练数据读取,数据增广等。

小白科研笔记:深入理解mmDetection框架——数据输入_第1张图片
图1:DataLoader的调用关系图

2.2 mmDetection中的数据输入

上一节讨论了PytorchDataloader类的原理图示,这一节看看mmDetection中数据输入的原理图示。有一点要明确的是,mmDetection的底层调用的还是Dataloader。下面的原理图示以三维目标检测SA-SSD为例子。

小白科研笔记:深入理解mmDetection框架——数据输入_第2张图片
图2:mmDetection初始化DataLoader类的图示

由上图可见,mmDetection首先调用函数get_dataset,根据配置文件cfg.data.train初始化一个数据集处理的类KittiLiDAR(相当于图1中的类trainset),然后再调用函数build_dataloader,使用类KittiLiDAR初始化类Dataloader

上述的流程需要频繁地做类实例化操作,需要用到函数obj_from_dict。初始化一个类需要初始条件(比如img=cam_1flip=True等等),我们把初始条件写进一个字典型变量里面(比如{"img":cam_1,"flip":True,...}),然后调用函数obj_from_dict,输入需要初始化类的类型以及初始化字典型变量,就可以完成这个类的初始化。

3. 结束语

这篇博客讨论了Dataloader类访问训练数据的流程,以及mmDetection调用Dataloader类的流程。

你可能感兴趣的:(computer,vision论文代码分析)