A brief analysis of how SA-SSD generates training data with the mmdetection framework

1. Introduction

My previous post gave a brief overview of SA-SSD's preprocessing, training, and evaluation pipeline. This post goes a step further and analyzes how SA-SSD uses mmdetection to build its training set. Since SA-SSD is implemented on top of mmdetection, part of the relevant mmdetection code is covered as well. As a beginner, I will also write down the points I did not understand at first.

2. A brief look at mmdetection's training flow

2.1 Overall training flow

Since SA-SSD is built on mmdetection, the first thing to understand is the training flow under the mmdetection framework. SA-SSD's training code is shown below:

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    train_dataset = get_dataset(cfg.data.train)

    train_detector(
        model,
        train_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)

From the code above, mmdetection's training flow can be roughly divided into three steps:

Step 1: build the detector, which corresponds to the function build_detector.

Step 2: load the training dataset, which corresponds to the function get_dataset.

Step 3: train the detector, which corresponds to the function train_detector.

In the code above, cfg plays a crucial role. cfg.model holds the model hyperparameters, cfg.data.train describes the training dataset, and cfg also records training details such as the optimizer and the learning rate schedule. Beginners should take a careful look at the cfg file in SA-SSD.

The function train_detector is a fairly self-contained module: it trains the detector according to the training details set in cfg. I do not need to modify it, so train_detector is not my focus. What I care about are get_dataset and build_detector. In SA-SSD, the training data is just the point cloud inside the camera field of view. What if my model needed the full LiDAR sweep, or RGB images, or stereo RGB, or even IMU information? How would I generate the train_dataset I need? That is the big question, and the importance of build_detector goes without saying as well. This post addresses the first question.

2.2 A deeper look at get_dataset

First, understand cfg.data.train. Its contents were already posted in the previous article, so I will not repeat them here. In cfg.data.train, the parameters related to the training data are: the dataset path, the image size, the image size divisor (related to the FPN-style padding), the image normalization parameters, the target classes (e.g. detecting only 3D car targets), the 3D ground-truth settings, the point-cloud voxelization parameters, and the anchor generation parameters.
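
To make this concrete, here is a rough, abbreviated sketch of what such a training config could look like. The key names follow the arguments of KittiLiDAR.__init__ shown in Section 2.3, but the values and the data_root path are illustrative rather than copied verbatim from SA-SSD's config file:

    # Illustrative sketch only; consult SA-SSD's own cfg file for the real values.
    data_root = '/path/to/KITTI/object/training/'      # hypothetical dataset path
    img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)   # example values

    data = dict(
        train=dict(
            type='KittiLiDAR',              # the string that obj_from_dict resolves (Section 2.2)
            root=data_root,
            ann_file=data_root + 'ImageSets/train.txt',
            img_prefix=None,                # SA-SSD does not feed RGB images to the model
            img_scale=(1242, 375),
            img_norm_cfg=img_norm_cfg,
            size_divisor=32,                # pad image sides to a multiple of 32 (FPN-style)
            flip_ratio=0.5,
            with_point=True,
            with_label=True,
            class_names=['Car', 'Van'],
            anchor_area_threshold=1,
            out_size_factor=2,
            # plus nested dicts for generator (voxelization), augmentor (GT sampling /
            # augmentation), and anchor_generator (3D anchors), omitted here
        ))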

Now for the function get_dataset. The previous article mentioned that obj_from_dict appears several times in the get_dataset code. An analysis of this core function can be found in another blog post; its code is shown below:

import sys

import mmcv


def obj_from_dict(info, parrent=None, default_args=None):
    """Initialize an object from dict.
    The dict must contain the key "type", which indicates the object type, it
    can be either a string or type, such as "list" or ``list``. Remaining
    fields are treated as the arguments for constructing the object.
    Args:
        info (dict): Object types and arguments.
        module (:class:`module`): Module which may containing expected object
            classes.
        default_args (dict, optional): Default arguments for initializing the
            object.
    Returns:
        any type: Object built from the dict.
    """
    assert isinstance(info, dict) and 'type' in info
    assert isinstance(default_args, dict) or default_args is None
    args = info.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        if parrent is not None:
            obj_type = getattr(parrent, obj_type)
        else:
            obj_type = sys.modules[obj_type]
    elif not isinstance(obj_type, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_type(**args)

It looks a bit involved. The docstring says the function "Initialize[s] an object from dict": in plain terms, it takes the dict-typed variable info and constructs an object of the class named by info['type'], looked up inside whatever module or namespace is passed as parrent; the remaining entries of the dict are the constructor arguments. The core call is getattr. In short, obj_from_dict is a helper for dictionary-driven instantiation.
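
To make the behavior concrete, here is a minimal toy sketch of my own (not from the SA-SSD code) that feeds obj_from_dict a hand-written dict and a stand-in for the parrent module; it assumes the obj_from_dict defined above (and mmcv, which it calls) are importable:

    from types import SimpleNamespace

    # A toy class standing in for something like a voxel generator; its arguments are made up.
    class ToyGenerator:
        def __init__(self, voxel_size, point_cloud_range):
            self.voxel_size = voxel_size
            self.point_cloud_range = point_cloud_range

    # Stand-in for a real module such as the voxel_generator module used in get_dataset.
    fake_module = SimpleNamespace(ToyGenerator=ToyGenerator)

    info = dict(type='ToyGenerator',
                voxel_size=[0.05, 0.05, 0.1],
                point_cloud_range=[0, -40.0, -3.0, 70.4, 40.0, 1.0])

    # Equivalent to getattr(fake_module, 'ToyGenerator')(**remaining_args)
    gen = obj_from_dict(info, fake_module)
    print(type(gen).__name__, gen.voxel_size)        # ToyGenerator [0.05, 0.05, 0.1]

This is exactly the mechanism get_dataset relies on below to turn data_cfg['generator'], data_cfg['augmentor'], and so on into live objects.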

OK, with obj_from_dict understood, it is time to dig into the source of get_dataset:

def get_dataset(data_cfg):
    # Collect the annotation index file(s); 'ann_file' is data_root + 'ImageSets/train.txt'
    # num_dset is the number of annotation files (i.e. sub-datasets), not the number of samples
    if isinstance(data_cfg['ann_file'], (list, tuple)):
        ann_files = data_cfg['ann_file']
        num_dset = len(ann_files)
    else:
        ann_files = [data_cfg['ann_file']]
        num_dset = 1

    # Not used by SA-SSD; this ends up as num_dset copies of None
    if 'proposal_file' in data_cfg.keys():
        if isinstance(data_cfg['proposal_file'], (list, tuple)):
            proposal_files = data_cfg['proposal_file']
        else:
            proposal_files = [data_cfg['proposal_file']]
    else:
        proposal_files = [None] * num_dset
    assert len(proposal_files) == num_dset

    # Not used by SA-SSD: the algorithm does not need images, so 'img_prefix' is None
    # The else branch replicates it num_dset times (here: num_dset copies of None)
    # If RGB were needed, set img_prefix in the cfg to the image folder, e.g. img_prefix=data_root + 'train2017/'
    if isinstance(data_cfg['img_prefix'], (list, tuple)):
        img_prefixes = data_cfg['img_prefix']
    else:
        img_prefixes = [data_cfg['img_prefix']] * num_dset
    assert len(img_prefixes) == num_dset

    # Build voxel_generator from the parameters in data_cfg['generator']; it voxelizes the point cloud during preprocessing
    if 'generator' in data_cfg.keys() and data_cfg['generator'] is not None:
        generator = obj_from_dict(data_cfg['generator'], voxel_generator)
    else:
        generator = None

    # Build point_augmentor from the parameters in data_cfg['augmentor']; it provides/augments the 3D ground truth
    if 'augmentor' in data_cfg.keys() and data_cfg['augmentor'] is not None:
        augmentor = obj_from_dict(data_cfg['augmentor'], point_augmentor)
    else:
        augmentor = None

    # Build anchor3d_generator from the parameters in data_cfg['anchor_generator']; it provides the 3D anchors
    if 'anchor_generator' in data_cfg.keys() and data_cfg['anchor_generator'] is not None:
        anchor_generator = obj_from_dict(data_cfg['anchor_generator'], anchor3d_generator)
    else:
        anchor_generator = None

    # Build bbox3d_target from the parameters in data_cfg['target_encoder']
    # SA-SSD does not seem to use it, so it stays None
    if 'target_encoder' in data_cfg.keys() and data_cfg['target_encoder'] is not None:
        target_encoder = obj_from_dict(data_cfg['target_encoder'], bbox3d_target)
    else:
        target_encoder = None

    dsets = []
    # Assemble the dataset(s) used for training
    for i in range(num_dset):
        # data_info is a dict that drives how each dataset is constructed
        data_info = copy.deepcopy(data_cfg)
        data_info['ann_file'] = ann_files[i]
        data_info['proposal_file'] = proposal_files[i]
        data_info['img_prefix'] = img_prefixes[i]
        if generator is not None:
            data_info['generator'] = generator
        if anchor_generator is not None:
            data_info['anchor_generator'] = anchor_generator
        if augmentor is not None:
            data_info['augmentor'] = augmentor
        if target_encoder is not None:
            data_info['target_encoder'] = target_encoder
        # Instantiate a dataset class from data_info
        dset = obj_from_dict(data_info, datasets)
        dsets.append(dset)
    if len(dsets) > 1:
        # Each of the objects built above is a dataset-class instance;
        # ConcatDataset merges them into one single dataset
        dset = ConcatDataset(dsets)
    else:
        dset = dsets[0]
    return dset

The most important line in the code above is dset = obj_from_dict(data_info, datasets). So I need to look at how this datasets object gets instantiated.

However, datasets here is not a single class: it is the datasets package, whose __all__ lists the concrete dataset classes it exposes:

__all__ = [
    'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
    'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
    'show_ann', 'get_dataset', 'KittiLiDAR','KittiVideo', 'VOCDataset'
]

Looking back at obj_from_dict with this in mind gives a deeper understanding: it instantiates an object of whatever class the string info['type'] names, looked up via getattr inside the parrent argument. When parrent is a package like datasets, the 'type' field of the dict automatically selects the matching class exposed by that package and initializes an instance of it.

Since the training config specifies the KittiLiDAR type, what gets built here is the KittiLiDAR class from the datasets package. It is analyzed in the next section.
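
In effect, assuming cfg.data.train carries type='KittiLiDAR', the lookup inside obj_from_dict reduces to roughly:

    from mmdet import datasets                       # the package whose __all__ is shown above

    # obj_from_dict resolves the string 'KittiLiDAR' to the class object ...
    dataset_cls = getattr(datasets, 'KittiLiDAR')
    # ... and then calls it with the remaining entries of data_info as keyword arguments.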

2.3 A brief look at the KittiLiDAR class

Its initialization code is shown below. You can see that the keys of the dict data_info match the arguments expected by KittiLiDAR's __init__.

class KittiLiDAR(Dataset):
    def __init__(self, root, ann_file,
                 img_prefix,
                 img_norm_cfg,
                 img_scale=(1242, 375),
                 size_divisor=32,
                 proposal_file=None,
                 flip_ratio=0.5,
                 with_point=False,
                 with_mask=False,
                 with_label=True,
                 class_names = ['Car', 'Van'],
                 augmentor=None,
                 generator=None,
                 anchor_generator=None,
                 anchor_area_threshold=1,
                 target_encoder=None,
                 out_size_factor=2,
                 test_mode=False):
        self.root = root
        self.img_scales = img_scale if isinstance(img_scale,
                                                  list) else [img_scale]
        assert mmcv.is_list_of(self.img_scales, tuple)
        # normalization configs
        self.img_norm_cfg = img_norm_cfg

        # flip ratio
        self.flip_ratio = flip_ratio

        # size_divisor (used for FPN)
        self.size_divisor = size_divisor
        self.class_names = class_names
        self.test_mode = test_mode
        self.with_label = with_label
        self.with_mask = with_mask
        self.with_point = with_point
        # Prefix paths for the various KITTI data (images, LiDAR, calibration, labels)
        self.img_prefix = osp.join(root, 'image_2')
        self.right_prefix = osp.join(root, 'image_3')
        self.lidar_prefix = osp.join(root, 'velodyne_reduced')
        self.calib_prefix = osp.join(root, 'calib')
        self.label_prefix = osp.join(root, 'label_2')

        with open(ann_file, 'r') as f:
            self.sample_ids = list(map(int, f.read().splitlines()))

        if not self.test_mode:
            self._set_group_flag()

        # transforms
        self.img_transform = ImageTransform(
            size_divisor=self.size_divisor, **self.img_norm_cfg)

        # voxel
        self.augmentor = augmentor
        self.generator = generator
        self.target_encoder = target_encoder
        self.out_size_factor = out_size_factor
        self.anchor_area_threshold = anchor_area_threshold
        # anchor
        if anchor_generator is not None:
            feature_map_size = self.generator.grid_size[:2] // self.out_size_factor
            feature_map_size = [*feature_map_size, 1][::-1]
            anchors = anchor_generator(feature_map_size)
            self.anchors = anchors.reshape([-1, 7])
            self.anchors_bv = rbbox2d_to_near_bbox(
                self.anchors[:, [0, 1, 3, 4, 6]])
        else:
            self.anchors=None

The DataLoader will call this class's __getitem__ function:

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data
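
Because data is full of DataContainer (DC) objects, the default PyTorch collate cannot batch it. As a rough sketch of my own, assuming mmcv's DC-aware collate is used (in practice mmdetection wraps all of this in its build_dataloader helper), the dataset would be consumed like this:

    from functools import partial

    from torch.utils.data import DataLoader
    from mmcv.parallel import collate


    def make_loader(dset, samples_per_gpu=2):
        """Wrap a KittiLiDAR-style dataset with mmcv's DataContainer-aware collate."""
        return DataLoader(dset,
                          batch_size=samples_per_gpu,
                          shuffle=True,
                          collate_fn=partial(collate, samples_per_gpu=samples_per_gpu))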

Now look at prepare_train_img, which produces the dict-typed variable data:

    def prepare_train_img(self, idx):
        sample_id = self.sample_ids[idx]

        # load image
        img = mmcv.imread(osp.join(self.img_prefix, '%06d.png' % sample_id))

        img, img_shape, pad_shape, scale_factor = self.img_transform(img, 1, False)

        objects = read_label(osp.join(self.label_prefix, '%06d.txt' % sample_id))
        calib = Calibration(osp.join(self.calib_prefix, '%06d.txt' % sample_id))

        gt_bboxes = [object.box3d for object in objects if object.type not in ["DontCare"]]
        gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
        gt_types = [object.type for object in objects if object.type not in ["DontCare"]]

        #gt_labels = np.ones(len(gt_bboxes), dtype=np.int64)

        # transfer from cam to lidar coordinates
        if len(gt_bboxes) != 0:
            gt_bboxes[:, :3] = project_rect_to_velo(gt_bboxes[:, :3], calib)

        img_meta = dict(
            img_shape=img_shape,
            sample_idx=sample_id,
            calib=calib
        )

        data = dict(
            img=to_tensor(img),
            img_meta = DC(img_meta, cpu_only=True)
        )

        if self.anchors is not None:
            data['anchors'] = DC(to_tensor(self.anchors.astype(np.float32)))

        if self.with_mask:
            NotImplemented

        if self.with_point:
            points = read_lidar(osp.join(self.lidar_prefix, '%06d.bin' % sample_id))

        if self.augmentor is not None and self.test_mode is False:
            sampled_gt_boxes, sampled_gt_types, sampled_points = self.augmentor.sample_all(gt_bboxes, gt_types)
            assert sampled_points.dtype == np.float32
            gt_bboxes = np.concatenate([gt_bboxes, sampled_gt_boxes])
            gt_types = gt_types + sampled_gt_types
            assert len(gt_types) == len(gt_bboxes)

            # to avoid overlapping point (option)
            masks = points_in_rbbox(points, sampled_gt_boxes)
            points = points[np.logical_not(masks.any(-1))]

            # paste sampled points to the scene
            points = np.concatenate([sampled_points, points], axis=0)

            # select the interest classes
            selected = [i for i in range(len(gt_types)) if gt_types[i] in self.class_names]
            gt_bboxes = gt_bboxes[selected, :]
            gt_types = [gt_types[i] for i in range(len(gt_types)) if gt_types[i] in self.class_names]

            # force van to have same label as car
            gt_types = ['Car' if n == 'Van' else n for n in gt_types]
            gt_labels = np.array([self.class_names.index(n) + 1 for n in gt_types], dtype=np.int64)

            self.augmentor.noise_per_object_(gt_bboxes, points, num_try=100)
            gt_bboxes, points = self.augmentor.random_flip(gt_bboxes, points)
            gt_bboxes, points = self.augmentor.global_rotation(gt_bboxes, points)
            gt_bboxes, points = self.augmentor.global_scaling(gt_bboxes, points)

        if isinstance(self.generator, VoxelGenerator):
            #voxels, coordinates, num_points = self.generator.generate(points)
            voxel_size = self.generator.voxel_size
            pc_range = self.generator.point_cloud_range
            grid_size = self.generator.grid_size

            keep = points_op_cpu.points_bound_kernel(points, pc_range[:3], pc_range[3:])
            voxels = points[keep, :]
            coordinates = ((voxels[:, [2, 1, 0]] - np.array(pc_range[[2,1,0]], dtype=np.float32)) / np.array(
                voxel_size[::-1], dtype=np.float32)).astype(np.int32)
            num_points = np.ones(len(keep)).astype(np.int32)

            data['voxels'] = DC(to_tensor(voxels.astype(np.float32)))
            data['coordinates'] = DC(to_tensor(coordinates))
            data['num_points'] = DC(to_tensor(num_points))

            if self.anchor_area_threshold >= 0 and self.anchors is not None:
                dense_voxel_map = sparse_sum_for_anchors_mask(
                    coordinates, tuple(grid_size[::-1][1:]))
                dense_voxel_map = dense_voxel_map.cumsum(0)
                dense_voxel_map = dense_voxel_map.cumsum(1)
                anchors_area = fused_get_anchors_area(
                    dense_voxel_map, self.anchors_bv, voxel_size, pc_range, grid_size)
                anchors_mask = anchors_area > self.anchor_area_threshold
                data['anchors_mask'] = DC(to_tensor(anchors_mask.astype(np.uint8)))

            # filter gt_bbox out of range
            bv_range = self.generator.point_cloud_range[[0, 1, 3, 4]]
            mask = filter_gt_box_outside_range(gt_bboxes, bv_range)
            gt_bboxes = gt_bboxes[mask]
            gt_labels = gt_labels[mask]

        else:
            NotImplementedError

        # skip the image if there is no valid gt bbox
        if len(gt_bboxes) == 0:
            return None

        # limit rad to [-pi, pi]
        gt_bboxes[:, 6] = limit_period(
            gt_bboxes[:, 6], offset=0.5, period=2 * np.pi)

        if self.with_label:
            data['gt_labels'] = DC(to_tensor(gt_labels))
            data['gt_bboxes'] = DC(to_tensor(gt_bboxes))


        return data

The dict data contains img, img_meta, anchors, voxels (the in-range point cloud), coordinates, num_points, anchors_mask, gt_labels, and gt_bboxes. In short, it holds both the network inputs and the 3D detection ground truth.
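
One detail from the VoxelGenerator branch above worth spelling out is the coordinates computation: SA-SSD keeps the raw in-range points (one "voxel" per point) and maps each of them to integer (z, y, x) grid indices. A toy NumPy sketch with made-up numbers (the voxel size and range here are typical KITTI-style settings, not necessarily the exact config values):

    import numpy as np

    pc_range   = np.array([0.0, -40.0, -3.0, 70.4, 40.0, 1.0], dtype=np.float32)  # x/y/z min, x/y/z max
    voxel_size = np.array([0.05, 0.05, 0.1], dtype=np.float32)                    # (x, y, z)

    points = np.array([[12.3,  4.7, -1.2, 0.9],     # x, y, z, intensity
                       [30.0, -2.5,  0.3, 0.4]], dtype=np.float32)

    # reorder to (z, y, x), shift by the range minimum, divide by the voxel size
    coords = ((points[:, [2, 1, 0]] - pc_range[[2, 1, 0]]) / voxel_size[::-1]).astype(np.int32)
    print(coords)     # [[ 18 894 246]
                      #  [ 33 750 600]]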

The function prepare_test_img is constructed almost the same way as prepare_train_img.

2.4 A question about the training process

This section looks at the train_detector function mentioned in Section 2.1. I always had one question: train_dataset provides so many fields, so how does the framework feed only the intended ones to the model? For example, SA-SSD only needs the point cloud, while train_dataset carries much more. Where in the code are the model inputs specified?

    # In the forward pass, data is unpacked straight into the model.
    # From KittiLiDAR we know that data contains img, img_meta, anchors, voxels, and so on.
    # So which of these variables actually take part in the forward computation?
    for i, data in enumerate(data_loader):
        results = model(**data)

After a long search, the answer turns out not to be in train_detector.

The answer lies with build_detector. SA-SSD's detector is built from the parent class SingleStageDetector, and its forward path, forward_train, spells out the model inputs:

    # img, img_meta, and everything in **kwargs come from the data provided by train_dataset
    def forward_train(self, img, img_meta, **kwargs):
        batch_size = len(img_meta)
        ret = self.merge_second_batch(kwargs)
        vx = self.backbone(ret['voxels'], ret['num_points'])
        # remaining code omitted

The code of merge_second_batch is shown below. kwargs is a dict; from it, voxels and num_points are picked out as inputs, and the 3D ground truth gt_labels and gt_bboxes is kept as well.

    def merge_second_batch(self, batch_args):
        ret = {}
        for key, elems in batch_args.items():
            if key in [
                'voxels', 'num_points',
            ]:
                ret[key] = torch.cat(elems, dim=0)
            elif key == 'coordinates':
                coors = []
                for i, coor in enumerate(elems):
                    coor_pad = F.pad(
                        coor, [1, 0, 0, 0],
                        mode='constant',
                        value=i)
                    coors.append(coor_pad)
                ret[key] = torch.cat(coors, dim=0)
            elif key in [
                'img_meta', 'gt_labels', 'gt_bboxes',
            ]:
                ret[key] = elems
            else:
                ret[key] = torch.stack(elems, dim=0)
        return ret
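
The 'coordinates' branch is worth a small illustration: each sample's (z, y, x) voxel coordinates get the sample index prepended as an extra left-hand column, so the voxels of all samples in a batch can be concatenated into one flat tensor while still being attributable to their sample. A toy sketch with made-up tensors:

    import torch
    import torch.nn.functional as F

    coords_sample0 = torch.tensor([[3, 10, 20], [4, 11, 21]])    # (z, y, x) of 2 voxels
    coords_sample1 = torch.tensor([[1,  5,  7]])                 # 1 voxel

    coors = []
    for i, coor in enumerate([coords_sample0, coords_sample1]):
        # pad one column on the left, filled with the sample index i
        coors.append(F.pad(coor, [1, 0, 0, 0], mode='constant', value=i))

    batched = torch.cat(coors, dim=0)
    print(batched)
    # tensor([[ 0,  3, 10, 20],
    #         [ 0,  4, 11, 21],
    #         [ 1,  1,  5,  7]])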

2.5 Interim summary

For SA-SSD, the train_dataset mentioned in Section 2.1 is a KittiLiDAR instance built over the whole training split (wrapped in ConcatDataset only if several annotation files are given).

3. Summary

mmdetection is a heavily engineered open-source framework; its code is fairly abstract and strongly encapsulated. Compared with the previous post, this one traces the flow of the input data in much more detail. Some may ask whether it wouldn't be easier to just use the framework as-is. But as a beginner I kept running into bugs when using mmdetection, and only by understanding how the framework fits together am I able to modify it. The next post will analyze SA-SSD's network architecture.
