首先上官方文档,本文希望能给看文档比较费劲的萌新一些帮助。
https://detectron2.readthedocs.io/tutorials/datasets.html
(一)使用自定义数据集
【1】网上很多教程是将数据集转化为coco数据集格式,然后直接使用detectron2自带的方法解析数据集。
coco解析数据集的方法位置:detectron2/data/datasets/coco.py的load_coco_json。
该种方法我就不做过多介绍了,网上很多。
【2】自己定义解析数据集解析方法(非coco)
(1)首先需要注册一下新数据集名字
def get_dicts():
#数据处理部分
return list[dict]
from detectron2.data import DatasetCatalog
DatasetCatalog.register('train数据集名字',get_dicts)
DatasetCatalog.register('val数据集名字', get_val_dicts)
#get_refer_dicts是对应解析数据集时需要调用的函数
注意,由于get_dicts()会将数据集的所有信息都读入内存中,所以比如读取图片、预处理的一些操作是不写在这里的(放在Mapper中,后面会说)。返回是一个list[dict]的数据集信息,其中每一个dict表示一张图片的信息,数据集中所有图片信息存在一个list里面。
(2)对于数据集中一些公共信息,比如说类别编号、id映射等信息,detectron2也是封装好的。通过MetaData来保存这些重复内容,例如(这部分就需要读一下api)
from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]
(二)数据集信息解析----自定义数据加载管道
再将数据信息传入网络前一般可能需要数据增强,图片读取等一系列操作,这部分主要有Mapper实现。
detectron2的默认方法位置:detectron2/data/dataset_mapper.py的__call__()方法
detectron2提供了一些数据处理的方法,比如:图片读取detectron.utils.read_image(),可以用给的方法。返回的仍然是list[dict]
【1】如果使用detectron2给定的网络,比如(rcnn、retinanet)就需要mapper函数中返回数据具有一些必须的key(instances,image),这部分建议于都dataset_mapper源码。
from detectron2.data import build_detection_train_loader
from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils
def mapper(dataset_dict):
# Implement a mapper, similar to the default DatasetMapper, but with your own customizations
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = utils.read_image(dataset_dict["file_name"], format="BGR")
image, transforms = T.apply_transform_gens([T.Resize((800, 800))], image)
dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
annos = [
utils.transform_instance_annotations(obj, transforms, image.shape[:2])
for obj in dataset_dict.pop("annotations")
if obj.get("iscrowd", 0) == 0
]
instances = utils.annotations_to_instances(annos, image.shape[:2])
dataset_dict["instances"] = utils.filter_empty_instances(instances)
return dataset_dict
data_loader = build_detection_train_loader(cfg, mapper=mapper)
# use this dataloader instead of the default
【2】如果使用自定义的网络框架,那么你需要啥就定义啥咯~
(三)网络定义
由于detectron2使用的是可以直接通过配置文件完成网络定义的模式。所以整个网络框架可以通过看configs理清条理。首先对于configs中没有的参数,detectron2是读取默认参数。
默认参数位置: detectron2/config/defaults.py
【1】META_ARCHITECTURE
首先会调用meta_architecture文件,位置:detectron2/modeling/meta_arch文件夹下。
写过pytorch的朋友都知道,网络的内容在forward函数中。传入参数batched_inputs就是前面mapper return的内容。其中batched_inputs是一个betch个数的dict.然后就可以进行各种网络操作啦~
【2】其他部分,将在(二)中继续,如果有什么问题可以留言~~
推荐其他大佬的博客:
http://objectdetection.cn/2020/03/04/detectron2-%e5%ae%89%e8%a3%85%e6%95%99%e7%a8%8b/
https://blog.csdn.net/weixin_43013761/article/details/104022615