之前已经处理出了kitti format的waymo dataset文件,但距离模型真正处理数据还有很长一段路。虽然在runner的train里,data的获取很简单,就是一个enumerate(dataloader):
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
# model forward && calc loss && backprop
self.call_hook('after_train_iter')
# optimize, save logs
self._iter += 1
但事实上数据要从file system读入到内存,经过pipeline类的格式处理、数据增强,进入gpu,然后才能给已经load进gpu的model进行训练。
在这个过程中有几个很关键的类在起作用,有dataset类,sampler类,dataloader类,mmdet的这几个类基本上继承自pytorch类。接下来结合waymo的代码介绍一下这三个类。
和c++里的set等stl类一样,iterator可以帮助我们遍历整个类,或者说容器的内容。在这里主要是sampler和dataloader有自己的iterator,这样就能实现for i,data in enumerate(dataloader): output = model(data)
的功能了。
对于dataloader来说,为了支持iterator,需要实现几个方法:
.__len__(self)
,一般返回类的长度,比如数据有几个frame。.__getitem__(self)
,定义获取容器指定元素的行为,这样就可以以data[index]的下标方式访问类里的信息。.__iter__(self)
:定义当迭代容器中的元素时的行为,我的理解是,这样使得for之初能获取到dataloader的iterator.__next__(self)
:定义迭代器的迭代规则,比如按某种顺序遍历整个set,for循环每次都会调用一下next。主要实现了数据从文件读到内存的功能,定义了__getitem__,这样当你dataset[index]的时候,就能得到某个frame的数据。主要类型有两种,第一种是Map-style dataset,就是只有set[index]的功能;第二种是Iterable-style dataset,实现了__iter__,也就是可以直接for xxx in enumerate(dataset)。一般用第一种。
我是用了transfusion改过的waymo dataset,主要的改动是支持了读入多张图片和对应label,原本只支持一张。
继承关系:CustomWaymoDataset–>KittiDataset–> Custom3DDataset -->torch_dataset
接下来分别讲讲每个部分的内容
在Custom3DDataset中,我们初始化了dataset的path,以及读入ann_file,即之前处理出的.pkl标注文件,记录到self.data_infos里。这样我们就知道了每张图片和对应calib等信息的path。对于指定了load_interval的情况,还会self.data_infos = self.data_infos[::load_interval]
。还有比较重要的参数是pipeline,以及相机个数num_view。
仅列出主要部分
data = self.prepare_train_data(idx)
return data
input_dict = self.get_data_info(index)
self.pre_pipeline(input_dict)
example = self.pipeline(input_dict)
return example
首先提取单帧的所有信息,放到一个dict里,然后pre_pipeline初始化dict里的一些key值,最后交给自定义的pipeline,把数据处理好再返回。(pipeline过程中好像数据已经进gpu了,我的理解,这里的是一个指向gpu的指针?)
主要从self.data_infos里得到该index的info,包括img,point cloud文件的path,点云到相机坐标的变换lidar2img,以及labels,即annotations。
之前也提到,waymo的pkl文件里只有img0的path和calib0的参数,所以返回结果也只有img0的。在这里transfusion修改了一下,根据img0的path得到了img0~4
的path和calib0~4
的lidar2img。
读入label会调用下面的get_ann_info函数,最后return input_dict
把self.data_infos的gt boxes整理好格式返回。
pipeline不被定义在dataset里,detr3d的操作如下,会从path中读img,做数据增强,filter等操作。最后的collect3d就相当于是只保留dict里的这几个keys。具体的内容可以查看class pipeline_name下的__call__函数。
train_pipeline = [
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(type='PhotoMetricDistortionMultiViewImage'),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='PadMultiViewImage', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'])
]
dataset已经实现了getitem的方法,但是如何对dataset进行采样还需要sampler定义。它主要就是每次训练的时候生成index,本身是一个类,也有自己的iterator,有多种sampler,比如SequentialSampler顺序采样,RandomSampler, WeightedRandomSampler等。
torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和Iterable-style Dataset,支持单进程/多进程,还可以通过参数设置如 sampler, batch size, pin memory 等自定义数据加载顺序以及控制数据批处理功能。其接口定义如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
num_worker定义了用于load数据的子进程,非0的时候,模型在跑data的时候,子进程会处理新的data,buffer大小为prefetch_factor,batch_size定义了迭代几次sampler来获取index,形成一个batch。也可以用batch_sampler一次性获得一个batch的index。
在for i,data in enumerate(dataloader)的时候,dataloader一次迭代会做很多事情,首先迭代self.sampler来获得key_list,即一个batch的index。接着调用dataset获取input_dict,这个时候形成一个list of dict,即[in1,in2,in3…],最后为了形成模型需要的nchw形式的tensor,还需要调用collatefn函数,把list of dict变成dict of list,即把[dict(a:x1,b:y1),dict(a:x2,b:y2),dict(…)]变成:
dict(a:[x1,x2,x3…],b:…)
整个过程抽象为
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
当然,在detr3d里,input tensor的形式是batch_size* num_view *chw的。不过bs=1。
具体的父类等依赖关系看教程二,多线程的逻辑稍有不同,不过我暂时没必要理解,不看了。
见另一篇文章
参考:
PyTorch 小课堂开课啦!带你解析数据处理全流程(一)
PyTorch 小课堂!带你解析数据处理全流程(二)