Trainer(cfg)
类trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
Trainer
类是由DefaultTrainer
类派生而来class Trainer(DefaultTrainer):
DefaultTrainer
类class DefaultTrainer(TrainerBase)
通过cfg文件,使用model、optimizer、dataloader来定义了一个具有默认训练逻辑的类,这里的dataloader就是数据训练器中数据获取部分。
扩展这个类,因为篇幅关系,只展开训练器中获取网络、优化器、以及本文关注的数据获取部分
class DefaultTrainer(TrainerBase):
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
super().__init__()
...
# Assume these objects must be constructed in this order.
# 这里就是通过cfg来获取网络、优化器、数据
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
...
self.build_train_loader(cfg)
函数 def build_train_loader(cls, cfg):
return build_detection_train_loader(cfg)
发现,直接调用build_detection_train_loader(cfg)
,这个返回的就是数据的一个迭代器
build_detection_train_loader
函数def build_detection_train_loader(
dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):
"""
Build a dataloader for object detection with some default features.
if isinstance(dataset, list):
dataset = DatasetFromList(dataset, copy=False)
if mapper is not None:
dataset = MapDataset(dataset, mapper)
if sampler is None:
sampler = TrainingSampler(len(dataset))
assert isinstance(sampler, torch.utils.data.sampler.Sampler)
return build_batch_data_loader(
dataset,
sampler,
total_batch_size,
aspect_ratio_grouping=aspect_ratio_grouping,
num_workers=num_workers,
)
这个函数就是为目标检测任务创建一个 data_loader
可以看到对dataset有两个操作,一个是DatasetFromList
,一个是MapDataset
。
MapDataset
函数至关重要,要想增加在线数据增强操作,就得定义一个mappper,adelaidet里面的写法如下:
def build_train_loader(cls, cfg):
mapper = DatasetMapperWithBasis(cfg, True)
return build_detection_train_loader(cfg, mapper)
这个mapper里包含了一些自定义的操作。
如果想看mapper里的具体写法,可以去了解adelaidet或者detectron2中
.data.dataset_mapper写法的区别
这里笔者将detectron2.data
里的内容抽取出来,并制作成一个demo。
官方默认在线数据增强模块中,只有Resize模块,这里添加了对比度变换的数据增强。
detectron2.data.transforms
里,官方已经定义了很多常规的数据增强方式
我们这里对比度变化的数据增强同样采用
augmentataion_impl
里的 RandomContrast
函数
class RandomContrast(Augmentation):
def __init__(self, intensity_min, intensity_max):
"""
Args:
intensity_min (float): Minimum augmentation
intensity_max (float): Maximum augmentation
"""
super().__init__()
self._init(locals())
RandomContrast中有两个参数,分别对应对比度调整最小值,对比度调整最大值,如果我们需要使用这个数据增强方式,就得在配置cfg文件中配置这两个参数
_C.INPUT.RANDOM_CONTRAST = CN({"ENABLED": False})
_C.INPUT.RANDOM_CONTRAST.SACLE_MIN = 1
_C.INPUT.RANDOM_CONTRAST.SACLE_MAX = 1
然后在detectron2.data.detection_utils.py
中的build_augmentation
中添加配置文件的解析字段,博主完整如下:
def build_augmentation(cfg, is_train):
"""
Create a list of default :class:`Augmentation` from config.
Now it includes resizing and flipping.
Returns:
list[Augmentation]
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
else:
min_size = cfg.INPUT.MIN_SIZE_TEST
max_size = cfg.INPUT.MAX_SIZE_TEST
sample_style = "choice"
augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
if is_train and cfg.INPUT.RANDOM_FLIP != "none":
augmentation.append(
T.RandomFlip(
horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
)
)
# add contrast data augmentation by hjxu in 2020.11.30
if is_train and cfg.INPUT.RANDOM_CONTRAST.ENABLED:
augmentation.append(
T.RandomContrast(
intensity_min=cfg.INPUT.RANDOM_CONTRAST.SACLE_MIN,
intensity_max=cfg.INPUT.RANDOM_CONTRAST.SACLE_MAX,
)
)
return augmentation
augmentations = utils.build_augmentation(cfg, is_train)
print("1:", augmentations)
augmentations = T.AugmentationList(augmentations)
print("2:", augmentations)
打印如下,为了美观,这里删除部分部分信息
1: [ResizeShortestEdge.., RandomFlip(), RandomContrast(intensity_min=0.3, intensity_max=2)]
2: AugmentationList[ResizeShortestEdge..., RandomFlip(), RandomContrast(intensity_min=0.3, intensity_max=2)]
0、配置cfg文件
1、读取coco数据集
2、定义 augmentations
3、利用augmentations对图像进行增强操作
4、利用transforms.apply_box(bbox)对目标框进行操作
5、利用opencv显示图像
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
* * *** * * * *
* * * ** * *
**** * ** * *
* * * ** * *
* * ** * * ****
@File :detectron2/data_aug_demo.py
@Date :2020/12/1 上午11:58
@Require :
@Author :hjxu2016
@Funtion :
"""
from detectron2.config import get_cfg
from detectron2.data import transforms as T
import detectron2.data.detection_utils as utils
import argparse
from pycocotools.coco import COCO
import cv2
from detectron2.structures import BoxMode
def argument_parser(epilog=None):
parser = argparse.ArgumentParser(description='manual to this script')
################################# detectron2 params ####################################################
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument("--config-file",
default="./COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml",
metavar="FILE", help="path to config file")
return parser
# 配置cfg文件
cfg = get_cfg()
args = argument_parser().parse_args()
cfg.merge_from_file(args.config_file)
# 设置对比度数据增强
cfg.INPUT.RANDOM_CONTRAST.ENABLED = True
cfg.INPUT.RANDOM_CONTRAST.SACLE_MIN = 0.3
cfg.INPUT.RANDOM_CONTRAST.SACLE_MAX = 2
cfg.INPUT.MIN_SIZE_TRAIN = (400, 672, 800, 1000)
is_train = True
print(cfg)
aColor = [(0, 255, 0, 0), (255, 0, 0, 0), (0, 0, 255, 0), (0, 255, 255, 0)]
DATA_DIR = "./Data/coco/2017//"
# 读取coco数据集
jsonFile = DATA_DIR + "/annotations/instances_val2017.json"
coco = COCO(jsonFile)
img_ids = sorted(coco.imgs.keys())
# print(img_ids)
coco.info()
img_id = img_ids[5]
dictImg = coco.loadImgs(img_id)
strImgName = dictImg[0]["file_name"]
annIds = coco.getAnnIds(imgIds=img_id, iscrowd=None)
anns = coco.loadAnns(annIds)
img_path = DATA_DIR + "/val2017/" + strImgName
# 从cfg中 定义数据增强方法
augmentations = utils.build_augmentation(cfg, is_train)
print("1:", augmentations)
augmentations = T.AugmentationList(augmentations)
print("2:", augmentations)
matImg = utils.read_image(img_path, format="BGR")
for i in range(50): # 循环50次,看50次数据操作后的结果
aug_input = T.AugInput(matImg)
transforms = augmentations(aug_input)
# 对数据进行在线数据增强
image, seg = aug_input.image, aug_input.sem_seg
# 对标签进行转换
for n in range(len(anns)):
# 坐标转换,将XYWH型的框转换成XYXY
bbox = BoxMode.convert(anns[n]["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
# 数据增强函数
transformed_bbox = transforms.apply_box(bbox)
# 为了了解XYWH和XYXY的区别,再将坐标转换回来,利用cv画在原图上。
transformed_bbox = BoxMode.convert(transformed_bbox, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
[x, y, w, h] = [transformed_bbox[0][0], transformed_bbox[0][1], transformed_bbox[0][2], transformed_bbox[0][3]]
x, y, w, h = int(x), int(y), int(w), int(h)
cv2.rectangle(image, (x, y), (x + w, y + h), aColor[anns[n]["category_id"]%4])
cv2.putText(image, str(anns[n]["category_id"]), (x, y), cv2.FONT_ITALIC, 0.7,
aColor[anns[n]["category_id"]%4], 2)
cv2.imshow("name", image)
cv2.waitKey(800)
if __name__ == "__main__":
print("end ... ")
为了文档完整,也为了下次方便阅读,这里补充以下官方文档中对数据增强的解释。
版权提示,以下摘自
Detectron2-数据增强(Data Augmentation)官方文档中文翻译
数据增强是训练环节重要的一环。Detectron2 的数据增强系统旨在达到以下的目标:
前两个特征覆盖了大部分的普通使用案例,在像albumentations这样的 library 中也是可获得的,对其他特性的支持增加了detectron2增强API的开销,我们下面会进行解释。
本教程重点介绍在编写新的 data loader 时如何使用扩充,以及如何编写新的扩充,如果你在 detectron2 使用了默认的
data loader ,它已经支持了获取用户提供的自定义扩展列表,就像在Detectron2-使用自定义Data Loader(Use Custom Data Loader)官方文档中文翻译中解释的。
特征1,2的基本用法如下:
from detectron2.data import transforms as T
# Define a sequence of augmentations:
augs = T.AugmentationList([
T.RandomBrightness(0.9, 1.1),
T.RandomFlip(prob=0.5),
T.RandomCrop("absolute", (640, 640))
]) # type: T.Augmentation
# Define the augmentation input ("image" required, others optional):
input = T.AugInput(image, boxes=boxes, sem_seg=sem_seg)
# Apply the augmentation:
transform = augs(input) # type: T.Transform
image_transformed = input.image # new image
sem_seg_transformed = input.sem_seg # new semantic segmentation
# For any extra data that needs to be augmented together, use transform, e.g.:
image2_transformed = transform.apply_image(image2)
polygons_transformed = transform.apply_polygons(polygons)
这里设计了三个基本概念,他们是:
大部分二维的增强只需要知道有关输入的图像,这样的增强可以通过如下的方式轻易实现:
class MyColorAugmentation(T.Augmentation):
def get_transform(self, image):
r = np.random.rand(2)
return T.ColorTransform(lambda x: x * r[0] + r[1] * 10)
class MyCustomResize(T.Augmentation):
def get_transform(self, image):
old_h, old_w = image.shape[:2]
new_h, new_w = int(old_h * np.random.rand()), int(old_w * 1.5)
return T.ResizeTransform(old_h, old_w, new_h, new_w)
augs = MyCustomResize()
transform = augs(input)
除了图像,给定 AugInput 的任何属性都可以使用,只要它们是函数签名的一部分,例如:
class MyCustomCrop(T.Augmentation):
def get_transform(self, image, sem_seg):
# decide where to crop using both image and sem_seg
return T.CropTransform(...)
augs = MyCustomCrop()
assert hasattr(input, "image") and hasattr(input, "sem_seg")
transform = augs(input)
更多关于Detectron2中关于数据增强的描述,见
Detectron2-数据增强(Data Augmentation)官方文档中文翻译
https://detectron2.readthedocs.io/tutorials/augmentation.html