【Detectron2】代码库学习-4. LazyConfig 配置文件

【Detectron2】代码库学习-4. LazyConfig 配置文件_第1张图片

目录

    • 1. 配置文件
    • 2. LazyConfig 导入导出
    • 3. 递归实例化
    • 4. 基于LazyConfig的训练步骤
      • 4.1 导入依赖库
      • 4.2 日志初始化
      • 4.3 训练
      • 4.4 评估
      • 4.5 训练流程
      • 4.6 主函数入口
      • 5. Tips

Detectron2是Facebook AI Research(FAIR)推出的基于Pytorch的视觉算法开源框架,主要聚焦于目标检测和分割任务等视觉算法,此外还支持全景分割,关键点检测,旋转框检测等任务。Detectron2继承自Detectron 和mask-rcnn。
Detectron2具有较强的灵活性和可扩展性,支持快速的单GPU训练,多GPU并行训练和多节点分布式训练。

1. 配置文件

Detectron2 原本采用的是基于一种 key-value的基础config 系统, 采用 YAML格式。但是YAML是一种非常受限制的语言,不能存储复杂的数据结构,因此转而使用 一种更强大的 配置文件系统 LazyConfig system。

YAML-维基百科 是一种人类可读、数据可序列化(可保持成文件和重新加载恢复)的语言, JSON 格式也是一种合法的YAML。原始 的YAML只支持编码 标量(字符串,整数,浮点数)和关系数组(map, 字典,hash表)。YAML推荐的后缀名为.yaml

2. LazyConfig 导入导出

直接采用 python 脚本作为 配置文件载体,可以通过 python代码快速操作。支持丰富的数据类型。可以运行简单的函数。通过python的import语法导入导出。
config_test.py

inputs = [1024, 960]  # 输入大小
batch_size = 128
train_dict = {"input": inputs, "batch_size": batch_size}

通过 detectron2 提供的API 加载配置文件。方便获取属性和配置, 但是代码无法补全

from detectron2.config import LazyConfig
cfg=LazyConfig.load("config_test.py")
print(cfg.train_dict.batch_size)  # 方便获取属性和配置, 但是代码无法补全,
LazyConfig.save(cfg, "test.yaml") # 导出配置到yaml文件, 部分无法序列化的数据类型不能保存,如numpy 数组

test.yaml

train_dict:
  batch_size: 128
  input: [1024, 960]

3. 递归实例化

LazyConfig 采用递归实例化 特性,将函数和类的调用表示为字典。在调用时并不会立即执行 对应的函数,只返回一个字典 描述这个 call, 只有在实例化时才真正执行。

from detectron2.config import instantiate, LazyCall
import torch.nn as nn
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)  # 调用nn.Conv2d, 并配置参数
layer_cfg.out_channels = 64   # can edit it afterwards , 修改 参数
layer = instantiate(layer_cfg)  # 实例化对象,创建一个2维卷积层

LazyCall

class LazyCall:
    def __init__(self, target):
        self._target = target
    def __call__(self, **kwargs):
        if is_dataclass(self._target):
            # omegaconf object cannot hold dataclass type
            # https://github.com/omry/omegaconf/issues/784
            target = _convert_target_to_string(self._target)
        else:
            target = self._target
        kwargs["_target_"] = target
        return DictConfig(content=kwargs, flags={"allow_objects": True})

instantiate

def instantiate(cfg):
    """
    Recursively instantiate objects defined in dictionaries 
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf
    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]  # 递归调用
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        cfg = {k: instantiate(v) for k, v in cfg.items()}
        cls = cfg.pop("_target_")
        cls = instantiate(cls)
        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(cls), f"_target_ {cls} does not define a callable object"
        try:
            return cls(**cfg)  ## 根据c
        except TypeError:
            logger = logging.getLogger(__name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise
    return cfg  # return as-is if don't know what to do

4. 基于LazyConfig的训练步骤

4.1 导入依赖库

import logging
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer, # 自动混合精度训练
    SimpleTrainer, 
    default_argument_parser,
    default_setup, # 默认配置参数
    default_writers,
    hooks,
    launch, # 分布式训练启动器
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm

4.2 日志初始化

logger = logging.getLogger("detectron2")

4.3 训练

def do_train(args, cfg):
    model = instantiate(cfg.model)  # 获取模型
    logger = logging.getLogger("detectron2") 
    logger.info("Model:\n{}".format(model)) 
    model.to(cfg.train.device)

    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer) # 获取优化器

    train_loader = instantiate(cfg.dataloader.train)# 获取训练dataloader

    model = create_ddp_model(model, **cfg.train.ddp) # 并行模型
    # 混合精度训练
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(  # checkpoint 管理
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )
    trainer.register_hooks(  # 注册回调函数
        [
            hooks.IterationTimer(), # 计时器
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process() # 主进程 周期保存 checkpoint
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), # 评估
            hooks.PeriodicWriter( # 保存训练日志
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)  # 初始化或者恢复训练
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)

4.4 评估

def do_test(cfg, model):
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret

4.5 训练流程

def main(args):
    cfg = LazyConfig.load(args.config_file) 
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args) # 默认日志,日志记录基础信息,备份配置文件

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint) # 加载权重
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)

4.6 主函数入口

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch( # 启动多GPU训练
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank, # 当前节点ID 
        dist_url=args.dist_url,
        args=(args,),
    )

5. Tips

  • 像python代码一样操作配置文件,将相同的配置独立出来,导入进来,而不是复制多份
  • 尽可能的保存配置文件的简洁,不需要的不写

你可能感兴趣的:(开源库,Detectron2,学习,python,深度学习)