理解Detectron2中数据读取以及在线数据增强流程

文章目录

    • 一、Detectron2中读取数据的流程
    • 二、读取数据、在线数据增强并显示
    • 三、Detectron2官方对数据增强部分的描述

以目标检测为例,探讨Detectron2中的数据读取流程以及Detectron2中是如何进行在线数据增强。

一、Detectron2中读取数据的流程

  • 根据cfg配置文件,建立一个Trainer(cfg)
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
  • Trainer类是由DefaultTrainer类派生而来
class Trainer(DefaultTrainer):
  • 再看DefaultTrainer
class DefaultTrainer(TrainerBase)

通过cfg文件,使用modeloptimizerdataloader来定义了一个具有默认训练逻辑的类,这里的dataloader就是数据训练器中数据获取部分。
扩展这个类,因为篇幅关系,只展开训练器中获取网络、优化器、以及本文关注的数据获取部分


class DefaultTrainer(TrainerBase):

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
		...
        # Assume these objects must be constructed in this order.
        # 这里就是通过cfg来获取网络、优化器、数据
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)
	...
  • 进入self.build_train_loader(cfg)函数
    def build_train_loader(cls, cfg):
        return build_detection_train_loader(cfg)

发现,直接调用build_detection_train_loader(cfg),这个返回的就是数据的一个迭代器

  • build_detection_train_loader函数
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):
    """
    Build a dataloader for object detection with some default features.
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )

这个函数就是为目标检测任务创建一个 data_loader
可以看到对dataset有两个操作,一个是DatasetFromList,一个是MapDataset
MapDataset函数至关重要,要想增加在线数据增强操作,就得定义一个mappper,adelaidet里面的写法如下:

    def build_train_loader(cls, cfg):
        mapper = DatasetMapperWithBasis(cfg, True)
        return build_detection_train_loader(cfg, mapper)

这个mapper里包含了一些自定义的操作。
如果想看mapper里的具体写法,可以去了解adelaidet或者detectron2中
.data.dataset_mapper写法的区别

二、读取数据、在线数据增强并显示

这里笔者将detectron2.data里的内容抽取出来,并制作成一个demo。
官方默认在线数据增强模块中,只有Resize模块,这里添加了对比度变换的数据增强。
detectron2.data.transforms里,官方已经定义了很多常规的数据增强方式
我们这里对比度变化的数据增强同样采用

augmentataion_impl里的 RandomContrast函数


class RandomContrast(Augmentation):
    def __init__(self, intensity_min, intensity_max):
        """
        Args:
            intensity_min (float): Minimum augmentation
            intensity_max (float): Maximum augmentation
        """
        super().__init__()
        self._init(locals())

RandomContrast中有两个参数,分别对应对比度调整最小值,对比度调整最大值,如果我们需要使用这个数据增强方式,就得在配置cfg文件中配置这两个参数

  • 首先在detectron2.config.defaults.py文件中,添加默认参数
_C.INPUT.RANDOM_CONTRAST = CN({"ENABLED": False})
_C.INPUT.RANDOM_CONTRAST.SACLE_MIN = 1
_C.INPUT.RANDOM_CONTRAST.SACLE_MAX = 1

然后在detectron2.data.detection_utils.py中的build_augmentation中添加配置文件的解析字段,博主完整如下:

def build_augmentation(cfg, is_train):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes resizing and flipping.

    Returns:
        list[Augmentation]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
    if is_train and cfg.INPUT.RANDOM_FLIP != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
                vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
            )
        )
    # add contrast data augmentation by hjxu in 2020.11.30
    if is_train and cfg.INPUT.RANDOM_CONTRAST.ENABLED:
        augmentation.append(
            T.RandomContrast(
                intensity_min=cfg.INPUT.RANDOM_CONTRAST.SACLE_MIN,
                intensity_max=cfg.INPUT.RANDOM_CONTRAST.SACLE_MAX,
            )
        )
    return augmentation
  • 来看看现在共有哪些数据处理方法
	augmentations = utils.build_augmentation(cfg, is_train)
    print("1:", augmentations)
    augmentations = T.AugmentationList(augmentations)
    print("2:", augmentations)

打印如下,为了美观,这里删除部分部分信息

1: [ResizeShortestEdge.., RandomFlip(), RandomContrast(intensity_min=0.3, intensity_max=2)]
2: AugmentationList[ResizeShortestEdge..., RandomFlip(), RandomContrast(intensity_min=0.3, intensity_max=2)]
  • 数据增强操作添加进来,开始模拟数据输入,并且在屏幕上显示
    以下分为四个步骤

0、配置cfg文件
1、读取coco数据集
2、定义 augmentations
3、利用augmentations对图像进行增强操作
4、利用transforms.apply_box(bbox)对目标框进行操作
5、利用opencv显示图像

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""                  
*  * *** *  * *  *      
*  *  *   **  *  *             
****  *   **  *  *                 
*  *  *   **  *  *         
*  * **  *  * ****  

@File     :detectron2/data_aug_demo.py  
@Date     :2020/12/1 上午11:58  
@Require  :   
@Author   :hjxu2016
@Funtion  :
"""

from detectron2.config import get_cfg
from detectron2.data import transforms as T
import detectron2.data.detection_utils as utils
import argparse
from pycocotools.coco import COCO
import cv2
from detectron2.structures import BoxMode

def argument_parser(epilog=None):
    parser = argparse.ArgumentParser(description='manual to this script')
    ################################# detectron2 params ####################################################
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument("--config-file",
                        default="./COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml",
                        metavar="FILE", help="path to config file")
    return parser
# 配置cfg文件
cfg = get_cfg()
args = argument_parser().parse_args()
cfg.merge_from_file(args.config_file)
# 设置对比度数据增强
cfg.INPUT.RANDOM_CONTRAST.ENABLED = True
cfg.INPUT.RANDOM_CONTRAST.SACLE_MIN = 0.3
cfg.INPUT.RANDOM_CONTRAST.SACLE_MAX = 2
cfg.INPUT.MIN_SIZE_TRAIN = (400, 672, 800, 1000)
is_train = True
print(cfg)
aColor = [(0, 255, 0, 0), (255, 0, 0, 0), (0, 0, 255, 0), (0, 255, 255, 0)]

DATA_DIR = "./Data/coco/2017//"
# 读取coco数据集
jsonFile = DATA_DIR + "/annotations/instances_val2017.json"
coco = COCO(jsonFile)
img_ids = sorted(coco.imgs.keys())
# print(img_ids)
coco.info()
img_id = img_ids[5]
dictImg = coco.loadImgs(img_id)
strImgName = dictImg[0]["file_name"]
annIds = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anns = coco.loadAnns(annIds)
img_path = DATA_DIR + "/val2017/" + strImgName

# 从cfg中 定义数据增强方法
augmentations = utils.build_augmentation(cfg, is_train)
print("1:", augmentations)
augmentations = T.AugmentationList(augmentations)
print("2:", augmentations)

matImg = utils.read_image(img_path, format="BGR")
for i in range(50):  # 循环50次,看50次数据操作后的结果
    aug_input = T.AugInput(matImg)
    transforms = augmentations(aug_input)
    # 对数据进行在线数据增强
    image, seg = aug_input.image, aug_input.sem_seg
    # 对标签进行转换
    for n in range(len(anns)):
        # 坐标转换,将XYWH型的框转换成XYXY
        bbox = BoxMode.convert(anns[n]["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
        # 数据增强函数
        transformed_bbox = transforms.apply_box(bbox)
        # 为了了解XYWH和XYXY的区别,再将坐标转换回来,利用cv画在原图上。
        transformed_bbox = BoxMode.convert(transformed_bbox, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        [x, y, w, h] = [transformed_bbox[0][0], transformed_bbox[0][1], transformed_bbox[0][2], transformed_bbox[0][3]]
        x, y, w, h = int(x), int(y), int(w), int(h)
        cv2.rectangle(image, (x, y), (x + w, y + h), aColor[anns[n]["category_id"]%4])
        cv2.putText(image, str(anns[n]["category_id"]), (x, y), cv2.FONT_ITALIC, 0.7,
                    aColor[anns[n]["category_id"]%4], 2)
    cv2.imshow("name", image)
    cv2.waitKey(800)

if __name__ == "__main__":

    print("end ... ")

三、Detectron2官方对数据增强部分的描述

为了文档完整,也为了下次方便阅读,这里补充以下官方文档中对数据增强的解释。
版权提示,以下摘自
Detectron2-数据增强(Data Augmentation)官方文档中文翻译

数据增强是训练环节重要的一环。Detectron2 的数据增强系统旨在达到以下的目标:

  • 允许将多个数据类型叠在一起(例如:图像和他们的边框或者 mask
  • 允许应用一个静态声明的递增序列
  • 允许添加自定义的新数据类型来增强(旋转的边框,视频剪辑等)
  • 处理和操作扩充应用的操作

前两个特征覆盖了大部分的普通使用案例,在像albumentations这样的 library 中也是可获得的,对其他特性的支持增加了detectron2增强API的开销,我们下面会进行解释。

本教程重点介绍在编写新的 data loader 时如何使用扩充,以及如何编写新的扩充,如果你在 detectron2 使用了默认的
data loader ,它已经支持了获取用户提供的自定义扩展列表,就像在Detectron2-使用自定义Data Loader(Use Custom Data Loader)官方文档中文翻译中解释的。

基本用法

特征1,2的基本用法如下:

from detectron2.data import transforms as T
# Define a sequence of augmentations:
augs = T.AugmentationList([
    T.RandomBrightness(0.9, 1.1),
    T.RandomFlip(prob=0.5),
    T.RandomCrop("absolute", (640, 640))
])  # type: T.Augmentation

# Define the augmentation input ("image" required, others optional):
input = T.AugInput(image, boxes=boxes, sem_seg=sem_seg)
# Apply the augmentation:
transform = augs(input)  # type: T.Transform
image_transformed = input.image  # new image
sem_seg_transformed = input.sem_seg  # new semantic segmentation

# For any extra data that needs to be augmented together, use transform, e.g.:
image2_transformed = transform.apply_image(image2)
polygons_transformed = transform.apply_polygons(polygons)

这里设计了三个基本概念,他们是:

  • T.Augmentation 定义了修改输入的 “policy”
    -它的 call(AugInput) -> Transform 方法对应地增加了输入,返回了已经应用的操作。
  • T.Transform 实现转换数据的实际操作
    • 它有像 apply_image, apply_coords 这样的办法去定义如何转换每种数据类型。
  • T.Auglnput 储存了 T.Augmentation 需要的输入,和他们应该转换成的输入。一些高级用法中需要这个高级的概念,直接使用这个 class 可以满足大部分的案例,因为 T.AugInput 之外的额外数据可以使用返回的转换进行扩充,如上面的示例所示。

写一个新的增强

大部分二维的增强只需要知道有关输入的图像,这样的增强可以通过如下的方式轻易实现:

class MyColorAugmentation(T.Augmentation):
    def get_transform(self, image):
        r = np.random.rand(2)
        return T.ColorTransform(lambda x: x * r[0] + r[1] * 10)

class MyCustomResize(T.Augmentation):
    def get_transform(self, image):
        old_h, old_w = image.shape[:2]
        new_h, new_w = int(old_h * np.random.rand()), int(old_w * 1.5)
        return T.ResizeTransform(old_h, old_w, new_h, new_w)

augs = MyCustomResize()
transform = augs(input)

除了图像,给定 AugInput 的任何属性都可以使用,只要它们是函数签名的一部分,例如:

class MyCustomCrop(T.Augmentation):
    def get_transform(self, image, sem_seg):
        # decide where to crop using both image and sem_seg
        return T.CropTransform(...)

augs = MyCustomCrop()
assert hasattr(input, "image") and hasattr(input, "sem_seg")
transform = augs(input)

更多关于Detectron2中关于数据增强的描述,见
Detectron2-数据增强(Data Augmentation)官方文档中文翻译
https://detectron2.readthedocs.io/tutorials/augmentation.html

你可能感兴趣的:(深度框架,Pytorch)