使用Detectron2训练自己的网络 (一)

首先上官方文档,本文希望能给看文档比较费劲的萌新一些帮助。

https://detectron2.readthedocs.io/tutorials/datasets.html

(一)使用自定义数据集

【1】网上很多教程是将数据集转化为coco数据集格式,然后直接使用detectron2自带的方法解析数据集。

coco解析数据集的方法位置:detectron2/data/datasets/coco.py的load_coco_json。

该种方法我就不做过多介绍了,网上很多。

【2】自己定义解析数据集解析方法(非coco)

(1)首先需要注册一下新数据集名字

def get_dicts():
    #数据处理部分
    return list[dict]  
from detectron2.data import DatasetCatalog
DatasetCatalog.register('train数据集名字',get_dicts)
DatasetCatalog.register('val数据集名字', get_val_dicts)
#get_refer_dicts是对应解析数据集时需要调用的函数

注意,由于get_dicts()会将数据集的所有信息都读入内存中,所以比如读取图片、预处理的一些操作是不写在这里的(放在Mapper中,后面会说)。返回是一个list[dict]的数据集信息,其中每一个dict表示一张图片的信息,数据集中所有图片信息存在一个list里面。

(2)对于数据集中一些公共信息,比如说类别编号、id映射等信息,detectron2也是封装好的。通过MetaData来保存这些重复内容,例如(这部分就需要读一下api)

from detectron2.data import MetadataCatalog
MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]

(二)数据集信息解析----自定义数据加载管道

再将数据信息传入网络前一般可能需要数据增强,图片读取等一系列操作,这部分主要有Mapper实现。

detectron2的默认方法位置:detectron2/data/dataset_mapper.py的__call__()方法

detectron2提供了一些数据处理的方法,比如:图片读取detectron.utils.read_image(),可以用给的方法。返回的仍然是list[dict]

【1】如果使用detectron2给定的网络,比如(rcnn、retinanet)就需要mapper函数中返回数据具有一些必须的key(instances,image),这部分建议于都dataset_mapper源码。

from detectron2.data import build_detection_train_loader
from detectron2.data import transforms as T
from detectron2.data import detection_utils as utils

def mapper(dataset_dict):
	# Implement a mapper, similar to the default DatasetMapper, but with your own customizations
	dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
	image = utils.read_image(dataset_dict["file_name"], format="BGR")
	image, transforms = T.apply_transform_gens([T.Resize((800, 800))], image)
	dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

	annos = [
		utils.transform_instance_annotations(obj, transforms, image.shape[:2])
		for obj in dataset_dict.pop("annotations")
		if obj.get("iscrowd", 0) == 0
	]
	instances = utils.annotations_to_instances(annos, image.shape[:2])
	dataset_dict["instances"] = utils.filter_empty_instances(instances)
	return dataset_dict

data_loader = build_detection_train_loader(cfg, mapper=mapper)
# use this dataloader instead of the default

【2】如果使用自定义的网络框架,那么你需要啥就定义啥咯~

(三)网络定义

由于detectron2使用的是可以直接通过配置文件完成网络定义的模式。所以整个网络框架可以通过看configs理清条理。首先对于configs中没有的参数,detectron2是读取默认参数。

默认参数位置: detectron2/config/defaults.py

【1】META_ARCHITECTURE

首先会调用meta_architecture文件,位置:detectron2/modeling/meta_arch文件夹下。

写过pytorch的朋友都知道,网络的内容在forward函数中。传入参数batched_inputs就是前面mapper return的内容。其中batched_inputs是一个betch个数的dict.然后就可以进行各种网络操作啦~

【2】其他部分,将在(二)中继续,如果有什么问题可以留言~~

 

推荐其他大佬的博客:

http://objectdetection.cn/2020/03/04/detectron2-%e5%ae%89%e8%a3%85%e6%95%99%e7%a8%8b/

https://blog.csdn.net/weixin_43013761/article/details/104022615

 

 

你可能感兴趣的:(pytorch)