基于detectron2框架的深度学习模型载入自定义数据集

基于detectron2框架的深度学习模型载入自定义数据集

一、前言

最近在做微光目标检测的研究工作,使用了Rank_DETR;这个模型是基于detrex框架,而detrex框架又是基于detectron2的。找了一圈没找到载入数据集的地方,后面查阅了资料得知要用API进行注册。

二、步骤

  1. 注册数据集:
    在脚本中,我们首先要注册数据集。Detectron2 提供了多种注册数据集的方式,常用的是 register_coco_instances,用于 COCO 格式的数据集。您可以在脚本的开头或配置文件中添加如下代码来注册您的数据集:

    from detectron2.data.datasets import register_coco_instances
    
    register_coco_instances("my_dataset_train", {}, "path/to/train_annotations.json", "path/to/train_images/")
    register_coco_instances("my_dataset_val", {}, "path/to/val_annotations.json", "path/to/val_images/")
    
    • "my_dataset_train""my_dataset_val" 是数据集的名称,您可以按需更改。
    • path/to/train_annotations.jsonpath/to/val_annotations.json 分别是训练和验证数据集的 COCO 格式标注文件路径。
    • path/to/train_images/path/to/val_images/ 是训练和验证图像的路径。
  2. 在配置文件中引用数据集:
    在您使用的配置文件中,需要确保数据加载器 (dataloader) 中引用了您刚才注册的数据集。通常,您需要修改以下内容:

    cfg.dataloader.train.dataset.names = "my_dataset_train"
    cfg.dataloader.test.dataset.names = "my_dataset_val"
    

    这确保了训练和验证时使用的是您自定义的数据集。

三、示例代码集成

如果您已经在脚本中集成了以上步骤,代码可能如下所示:

def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    register_coco_instances("exdark_train", {},
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_train2017.json",
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/train2017")
    register_coco_instances("exdark_test", {},
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_val2017.json",
                            "/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/val2017")
    cfg.dataloader.train.dataset.names = "exdark_train"
    cfg.dataloader.test.dataset.names = "exdark_test"

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)


if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--use_wandb", action="store_true", help="Whether to use wandb.")
    parser.add_argument("--wandb_key", type=str, help="Wandb API key.")
    args = parser.parse_args()

    if args.use_wandb:
        wandb.login(key=args.wandb_key)
        
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

你可能感兴趣的:(解决方案,深度学习,人工智能,计算机视觉)