[拆轮子] PaddleDetection 中的 config 加载过程

当前 PaddleDetection 版本为 2.5。氵这篇博客的原因是:我很好奇为什么最终加载的 config 中会合并 global_config,于是单步调试了一下。

tools/eval.py 和 tools/train.py 中,与配置有关的代码位于 main 函数中:

    FLAGS = parse_args()  # CLI args; FLAGS.opt holds the -o options parsed into a dict
    cfg = load_config(FLAGS.config)  # load the YAML config (and its _BASE_ files)
    merge_args(cfg, FLAGS)  # NOTE(review): merge_args is not shown here; presumably copies CLI flags into cfg
    merge_config(FLAGS.opt)  # merge the -o overrides into global_config

FLAGS = parse_args() 借助 ppdet/utils/cli.py 中的 ArgsParser 解析命令行参数;ArgsParser 继承自 argparse.ArgumentParser:

class ArgsParser(ArgumentParser):
    """Command-line parser exposing ``-c/--config`` and ``-o/--opt``.

    ``--opt`` takes ``key=value`` strings; dotted keys such as ``a.b.c=1``
    build nested dicts, and every value is parsed with YAML.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='*', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert the ``--opt`` list into a nested dict.

        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn a list of ``key=value`` strings into a (nested) dict.

        Args:
            opts (list[str] | None): raw ``--opt`` values.

        Returns:
            dict: e.g. ``['a.b=1', 'a.c=2']`` -> ``{'a': {'b': 1, 'c': 2}}``.
            When the same key is given twice, the later option wins.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, v = s.split('=', 1)
            # NOTE: yaml.Loader can construct arbitrary Python objects from
            # YAML tags; command-line values are assumed to be trusted.
            value = yaml.load(v, Loader=yaml.Loader)
            keys = k.split('.')
            cur = config
            for key in keys[:-1]:
                # Reuse an existing sub-dict so sibling options such as
                # ``a.b.c=1 a.b.d=2`` accumulate instead of overwriting each
                # other; a non-dict value at this level is replaced by a dict.
                if not isinstance(cur.get(key), dict):
                    cur[key] = {}
                cur = cur[key]
            cur[keys[-1]] = value
        return config

ArgsParser 的 __init__ 中注册了 --config 和 --opt 两个参数:前者用来指定模型配置文件,后者用来在命令行中修改该配置文件中的内容(形如 -o key=value,key 可以用 . 表示嵌套层级)。

_parse_opt 方法被 parse_args 调用,负责把 --opt 的内容解析为(可嵌套的)字典并返回。

FLAGS = parse_args() 执行之后,FLAGS.config 是配置的路径

cfg = load_config(FLAGS.config) # load the YAML config found at this path (resolving its _BASE_ files)

以 deformable-detr 的 config 为例,FLAGS.config 为:

configs/deformable_detr/deformable_detr_r50_1x_coco.yml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/deformable_optimizer_1x.yml',
  '_base_/deformable_detr_r50.yml',
  '_base_/deformable_detr_reader.yml',
]
weights: output/deformable_detr_r50_1x_coco
find_unused_parameters: True

config 文件中 _BASE_ 列表里的配置文件会先被依次读取并合并(排在后面的覆盖排在前面的),之后该 config 自身的内容会覆盖掉之前的同名配置。

以下是 load_config 源码:

def load_config(file_path):
    """
    Read a YAML config (resolving its ``_BASE_`` chain) into the global config.

    Args:
        file_path (str): Path of the ``.yml`` / ``.yaml`` file to load.

    Returns:
        The module-level ``global_config`` after the loaded file has been
        merged into it.
    """
    extension = os.path.splitext(file_path)[1]
    assert extension in ['.yml', '.yaml'], "only support yaml files for now"

    loaded = _load_config_with_base(file_path)
    # Record the bare file name (no directory, no extension) under 'filename'.
    stem, _ = os.path.splitext(os.path.basename(file_path))
    loaded['filename'] = stem
    merge_config(loaded)

    return global_config

load_config 调用 _load_config_with_base 函数去加载配置

# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    """
    Load a YAML config file, recursively resolving its ``_BASE_`` list.

    The files listed under ``BASE_KEY`` are loaded (recursively) in order and
    merged, later entries overriding earlier ones; the keys of ``file_path``
    itself are merged last, so they take priority over every base file.

    Args:
        file_path (str): Path of the YAML config file.

    Returns:
        The merged config for ``file_path``. An empty YAML file yields ``{}``.
    """
    with open(file_path) as f:
        # NOTE: yaml.Loader can construct arbitrary Python objects from YAML
        # tags; only load trusted config files.
        file_cfg = yaml.load(f, Loader=yaml.Loader)
    if file_cfg is None:
        # An empty file parses to None; treat it as an empty config so the
        # ``BASE_KEY in file_cfg`` test below does not raise TypeError.
        file_cfg = {}

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    if BASE_KEY in file_cfg:
        all_base_cfg = AttrDict()
        base_ymls = list(file_cfg[BASE_KEY])
        for base_yml in base_ymls:
            if base_yml.startswith("~"):
                base_yml = os.path.expanduser(base_yml)
            # Relative paths are resolved against the including file's dir.
            if not os.path.isabs(base_yml):
                base_yml = os.path.join(os.path.dirname(file_path), base_yml)

            # No explicit open here: the recursive call opens base_yml itself.
            base_cfg = _load_config_with_base(base_yml)
            # NOTE(review): merge_config picks its target with
            # `another_cfg or global_config`; if an empty AttrDict is falsy
            # (as a dict subclass would be), the first merge here goes into
            # global_config itself.
            all_base_cfg = merge_config(base_cfg, all_base_cfg)

        del file_cfg[BASE_KEY]
        return merge_config(file_cfg, all_base_cfg)

    return file_cfg

其中 BASE_KEY 就是常量 '_BASE_'(BASE_KEY = '_BASE_')。

之前提到过,config 可以在 _BASE_ 中列出一些基础配置文件,而这些基础配置文件本身也可能带有 _BASE_,所以 _load_config_with_base 这个加载配置的函数会递归调用自己来加载配置。

此处先读取当前配置文件本身(其中可能包含 _BASE_ 列表):

    # Read the current config file itself; it may contain a _BASE_ list.
    # NOTE: yaml.Loader is PyYAML's full (unsafe) loader -- trusted files only.
    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)

之后用 for 循环依次(递归地)读取 _BASE_ 中的配置文件并合并,最后再把当前文件的配置合并进去并返回:

		all_base_cfg = AttrDict()
        for base_yml in base_ymls:
				......

            with open(base_yml) as f:
                base_cfg = _load_config_with_base(base_yml)
                all_base_cfg = merge_config(base_cfg, all_base_cfg)

		del file_cfg[BASE_KEY] # 删除 _BASE_ 列表
        return merge_config(file_cfg, all_base_cfg)

merge_config 中也有值得一说的部分:

def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict, optional): Merge target. It is selected with
            ``another_cfg or global_config``, so any falsy value -- ``None``
            or an empty dict -- falls back to ``global_config``.

    Returns: the merge target (``another_cfg`` when truthy, otherwise
        ``global_config``), updated in place by ``dict_merge``.
    """
    global global_config
    # NOTE(review): because of `or`, passing a freshly created empty dict
    # (e.g. an empty AttrDict, if it is falsy like a dict) merges into
    # global_config rather than into that dict. Use
    # `another_cfg if another_cfg is not None else global_config` if that
    # is not the intended behavior.
    dct = another_cfg or global_config
    return dict_merge(dct, config)

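由于这里写的是 dct = another_cfg or global_config,当 another_cfg 为 None 或空字典 {} 时(空字典的布尔值为 False),合并目标都会退回到 global_config。而 _load_config_with_base 中第一次合并时传入的 all_base_cfg 正是一个空的 AttrDict(),因此配置会直接合并进 global_config。所以 eval.py 和 train.py 中的配置会含有之前注册过的模块的配置,只不过这些模块的配置都是默认值。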

dict_merge 中,会将 merge_dct 合并到 dct

def dict_merge(dct, merge_dct):
    """Recursively merge ``merge_dct`` into ``dct`` in place.

    Works like :meth:`dict.update`, except that when both sides hold a
    mapping under the same key, the two mappings are merged key by key
    (to any depth) instead of the right-hand one replacing the left.

    Args:
        dct: dict that receives the merged values (mutated in place).
        merge_dct: mapping whose entries are merged into ``dct``; its
            values win on conflicts.

    Returns: dct
    """
    for key, incoming in merge_dct.items():
        existing_is_dict = key in dct and isinstance(dct[key], dict)
        if not (existing_is_dict and
                isinstance(incoming, collectionsAbc.Mapping)):
            # New key, or at least one side is not a mapping: replace.
            dct[key] = incoming
            continue
        # Both sides are mappings: descend and merge the nested keys.
        dict_merge(dct[key], incoming)
    return dct

再说一下用 -o 指定的外部配置的合并过程:

FLAGS = parse_args()  # ArgsParser.parse_args converts the -o list into a dict via _parse_opt

ArgsParser.parse_args 内部会调用 _parse_opt 方法,即执行 args.opt = self._parse_opt(args.opt),把 -o 的内容解析为字典

之后执行 merge_config(FLAGS.opt),将 FLAGS.opt 与 global_config 合并(因为没有传入 another_cfg,合并目标就是 global_config)

最后调用 check_config(cfg) 来检查配置是否正确

def check_config(cfg):
    """
    Check the correctness of the configuration file. Log error and exit
    when Config is not compliant.

    Args:
        cfg: mapping holding the loaded config.

    Returns:
        The same ``cfg`` object, with ``log_iter`` set to 20 if it was
        missing.
    """
    err = "'{}' not specified in config file. Please set it in config file."
    check_list = ['architecture', 'num_classes']
    # No blanket try/except here: sys.exit raises SystemExit, which
    # `except Exception` does not catch anyway, so such a guard could only
    # hide unrelated errors.
    for var in check_list:
        if var not in cfg:
            logger.error(err.format(var))
            sys.exit(1)

    if 'log_iter' not in cfg:
        # Item assignment also works on plain dicts. NOTE(review): assumes
        # AttrDict exposes keys as attributes, so cfg.log_iter still reads 20.
        cfg['log_iter'] = 20

    return cfg

该函数检查 cfg 中是否有 'architecture' 和 'num_classes' 这两项配置,如果缺少任一项则打印错误并直接退出;另外,如果没有配置 log_iter,则将其默认设为 20。

附录:ppdet/core/workspace.py 相关配置代码

# Module-level config that load_config / merge_config merge into by default.
global_config = AttrDict()

# Key under which a YAML config lists the base config files it extends.
BASE_KEY = '_BASE_'


# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    """
    Load a YAML config file, recursively resolving its ``_BASE_`` list.

    The files listed under ``BASE_KEY`` are loaded (recursively) in order and
    merged, later entries overriding earlier ones; the keys of ``file_path``
    itself are merged last, so they take priority over every base file.

    Args:
        file_path (str): Path of the YAML config file.

    Returns:
        The merged config for ``file_path``. An empty YAML file yields ``{}``.
    """
    with open(file_path) as f:
        # NOTE: yaml.Loader can construct arbitrary Python objects from YAML
        # tags; only load trusted config files.
        file_cfg = yaml.load(f, Loader=yaml.Loader)
    if file_cfg is None:
        # An empty file parses to None; treat it as an empty config so the
        # ``BASE_KEY in file_cfg`` test below does not raise TypeError.
        file_cfg = {}

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    if BASE_KEY in file_cfg:
        all_base_cfg = AttrDict()
        base_ymls = list(file_cfg[BASE_KEY])
        for base_yml in base_ymls:
            if base_yml.startswith("~"):
                base_yml = os.path.expanduser(base_yml)
            # Relative paths are resolved against the including file's dir.
            if not os.path.isabs(base_yml):
                base_yml = os.path.join(os.path.dirname(file_path), base_yml)

            # No explicit open here: the recursive call opens base_yml itself.
            base_cfg = _load_config_with_base(base_yml)
            # NOTE(review): merge_config picks its target with
            # `another_cfg or global_config`; if an empty AttrDict is falsy
            # (as a dict subclass would be), the first merge here goes into
            # global_config itself.
            all_base_cfg = merge_config(base_cfg, all_base_cfg)

        del file_cfg[BASE_KEY]
        return merge_config(file_cfg, all_base_cfg)

    return file_cfg


def load_config(file_path):
    """
    Read a YAML config (resolving its ``_BASE_`` chain) into the global config.

    Args:
        file_path (str): Path of the ``.yml`` / ``.yaml`` file to load.

    Returns:
        The module-level ``global_config`` after the loaded file has been
        merged into it.
    """
    extension = os.path.splitext(file_path)[1]
    assert extension in ['.yml', '.yaml'], "only support yaml files for now"

    loaded = _load_config_with_base(file_path)
    # Record the bare file name (no directory, no extension) under 'filename'.
    stem, _ = os.path.splitext(os.path.basename(file_path))
    loaded['filename'] = stem
    merge_config(loaded)

    return global_config


def dict_merge(dct, merge_dct):
    """Recursively merge ``merge_dct`` into ``dct`` in place.

    Works like :meth:`dict.update`, except that when both sides hold a
    mapping under the same key, the two mappings are merged key by key
    (to any depth) instead of the right-hand one replacing the left.

    Args:
        dct: dict that receives the merged values (mutated in place).
        merge_dct: mapping whose entries are merged into ``dct``; its
            values win on conflicts.

    Returns: dct
    """
    for key, incoming in merge_dct.items():
        existing_is_dict = key in dct and isinstance(dct[key], dict)
        if not (existing_is_dict and
                isinstance(incoming, collectionsAbc.Mapping)):
            # New key, or at least one side is not a mapping: replace.
            dct[key] = incoming
            continue
        # Both sides are mappings: descend and merge the nested keys.
        dict_merge(dct[key], incoming)
    return dct


def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict, optional): Merge target. It is selected with
            ``another_cfg or global_config``, so any falsy value -- ``None``
            or an empty dict -- falls back to ``global_config``.

    Returns: the merge target (``another_cfg`` when truthy, otherwise
        ``global_config``), updated in place by ``dict_merge``.
    """
    global global_config
    # NOTE(review): because of `or`, passing a freshly created empty dict
    # (e.g. an empty AttrDict, if it is falsy like a dict) merges into
    # global_config rather than into that dict. Use
    # `another_cfg if another_cfg is not None else global_config` if that
    # is not the intended behavior.
    dct = another_cfg or global_config
    return dict_merge(dct, config)

你可能感兴趣的:(每日一氵,PaddleDetection,paddlepaddle历险记,python,人工智能,java)