当前 PaddleDetection 版本 2.5
,氵这篇博客的原因是我很好奇为什么 最终加载的 config 中会合并 global_config
于是单步调试了一下
在 tools/eval.py | tools/train.py
中,与配置有关的部分在 main 函数部分:
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_args(cfg, FLAGS)
merge_config(FLAGS.opt)
FLAGS = parse_args()
部分用来借助 ppdet/utils/cli.py
中的 ArgsParser
进行命令行参数的解析,而 ArgsParser
继承自 argparse.ArgumentParser
用来解析参数
class ArgsParser(ArgumentParser):
    """Command-line parser exposing ``-c/--config`` and ``-o/--opt``.

    ``--opt`` takes zero or more ``key=value`` items. A dotted key such as
    ``a.b.c=1`` addresses a nested entry, and every value is parsed as YAML.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='*', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` and turn the ``--opt`` items into a nested dict.

        Args:
            argv (list[str] | None): Arguments to parse; forwarded unchanged
                to ``ArgumentParser.parse_args``.

        Returns:
            The parsed namespace, with ``opt`` replaced by the dict built by
            ``_parse_opt``.

        Raises:
            AssertionError: If ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Build a nested dict from ``key=value`` option strings.

        Args:
            opts (list[str] | None): Items such as ``"use_gpu=true"`` or
                ``"TrainReader.batch_size=2"``.

        Returns:
            dict: One entry per top-level key; dotted keys become nested
            dicts. Empty when ``opts`` is ``None`` or empty.

        Raises:
            ValueError: If an item does not contain ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "invalid option '{}', expected key=value".format(s))
            k, v = s.split('=', 1)
            keys = k.split('.')

            # Walk down to the parent of the leaf, reusing dicts created by
            # earlier options so that sibling keys (a.b.c and a.b.d) are both
            # kept instead of the later one resetting the shared parent.
            node = config
            for key in keys[:-1]:
                child = node.get(key)
                if not isinstance(child, dict):
                    # Missing, or a scalar set by an earlier option: the
                    # later, more specific option replaces it.
                    child = {}
                    node[key] = child
                node = child

            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; acceptable only because the text comes from the
            # user's own command line.
            node[keys[-1]] = yaml.load(v, Loader=yaml.Loader)
        return config
ArgsParser
在 __init__
中实现了解析参数 --config
和 --opt
的功能,前者用来指定模型配置文件,后者用来在命令行修改前者配置文件中的内容
_parse_opt
方法解析 --opt
的内容并返回,被 parse_args
调用
FLAGS = parse_args()
执行之后,FLAGS.config
是配置的路径
cfg = load_config(FLAGS.config) # 加载配置的路径
以 FLAGS.config
是 deformable-detr config 为例:
configs/deformable_detr/deformable_detr_r50_1x_coco.yml
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/deformable_optimizer_1x.yml',
'_base_/deformable_detr_r50.yml',
'_base_/deformable_detr_reader.yml',
]
weights: output/deformable_detr_r50_1x_coco
find_unused_parameters: True
config 文件中 _BASE_
列表中内容会被读取,之后该 config 中内容会覆盖掉之前的配置(如果有相同的配置)
以下是 load_config
源码:
def load_config(file_path):
    """
    Load a YAML config file and merge it into the global config.

    Args:
        file_path (str): Path of the config file to be loaded; its extension
            must be ``.yml`` or ``.yaml``.

    Returns: global config
    """
    extension = os.path.splitext(file_path)[1]
    assert extension in ('.yml', '.yaml'), "only support yaml files for now"

    # Resolve the file together with everything it pulls in through _BASE_.
    loaded = _load_config_with_base(file_path)

    # Record the config's base name (directory and extension stripped).
    stem, _ = os.path.splitext(os.path.basename(file_path))
    loaded['filename'] = stem

    merge_config(loaded)
    return global_config
load_config
调用 _load_config_with_base
函数去加载配置
# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    """Load a YAML config file, resolving its ``_BASE_`` list recursively.

    Each path listed under ``BASE_KEY`` is loaded first, in list order, and
    merged; the keys of ``file_path`` itself are merged last, so they take
    priority over the base files. Relative base paths are resolved against
    the directory of ``file_path``.

    Args:
        file_path (str): Path of the YAML file to load.

    Returns:
        The parsed content of ``file_path`` when it has no ``BASE_KEY``
        entry, otherwise the mapping returned by the final ``merge_config``
        call.
    """
    with open(file_path) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load config files from trusted sources.
        file_cfg = yaml.load(f, Loader=yaml.Loader)
    if file_cfg is None:
        # An empty YAML document parses to None; treat it as an empty config.
        file_cfg = {}

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    if BASE_KEY in file_cfg:
        # NOTE(review): merge_config falls back to global_config when its
        # second argument is falsy, so while all_base_cfg is still empty the
        # first merge appears to target global_config (assumes an empty
        # AttrDict is falsy - TODO confirm).
        all_base_cfg = AttrDict()
        for base_yml in file_cfg[BASE_KEY]:
            if base_yml.startswith("~"):
                base_yml = os.path.expanduser(base_yml)
            if not os.path.isabs(base_yml):
                base_yml = os.path.join(os.path.dirname(file_path), base_yml)

            # The recursive call opens the file itself, so no separate
            # open() is needed here.
            base_cfg = _load_config_with_base(base_yml)
            all_base_cfg = merge_config(base_cfg, all_base_cfg)

        del file_cfg[BASE_KEY]
        return merge_config(file_cfg, all_base_cfg)

    return file_cfg
BASE_KEY
变量就是 BASE_KEY = '_BASE_'
之前提到过,config中可能会在 _BASE_
中写一些基础的配置文件,所以 _load_config_with_base
这个加载配置的函数中会递归调用自己来加载配置
此处先加载当前配置文件自身的内容
with open(file_path) as f:
file_cfg = yaml.load(f, Loader=yaml.Loader)
之后 for 循环依次读取 _BASE_ 中列出的配置文件并合并,最后再与当前文件自身的配置合并后返回
all_base_cfg = AttrDict()
for base_yml in base_ymls:
......
with open(base_yml) as f:
base_cfg = _load_config_with_base(base_yml)
all_base_cfg = merge_config(base_cfg, all_base_cfg)
del file_cfg[BASE_KEY] # 删除 _BASE_ 列表
return merge_config(file_cfg, all_base_cfg)
而 merge_config
中也有值得一说的部分,
def merge_config(config, another_cfg=None):
    """
    Merge config into another_cfg, or into the global config.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict): Merge target. When it is None or empty (any
            falsy value), the global config is used as the target instead.

    Returns: the merge target, updated in place by dict_merge
    """
    global global_config
    target = global_config
    if another_cfg:
        target = another_cfg
    return dict_merge(target, config)
如果 another_cfg
为空(None 或 {},此时 another_cfg or global_config 取到的是 global_config)
,则先和 global_config
合并,所以 eval.py 和 train.py 中的配置会含有之前注册过的模块的配置,只不过模块的配置都是默认的
而 dict_merge
中,会将 merge_dct
合并到 dct
中
def dict_merge(dct, merge_dct):
    """Recursively merge ``merge_dct`` into ``dct`` in place.

    Works like ``dict.update()``, except that when a key holds a dict in
    ``dct`` and a mapping in ``merge_dct`` the two are merged key by key
    (to any depth) instead of the old value being replaced wholesale.

    Args:
        dct: dict onto which the merge is executed
        merge_dct: dct merged into dct

    Returns: dct
    """
    for key, incoming in merge_dct.items():
        both_nested = (key in dct
                       and isinstance(dct[key], dict)
                       and isinstance(incoming, collectionsAbc.Mapping))
        if not both_nested:
            # New key, or at least one side is not a mapping: overwrite.
            dct[key] = incoming
            continue
        dict_merge(dct[key], incoming)
    return dct
-o
指定外部配置的合并过程:FLAGS = parse_args()
内部的 ArgsParser
类调用 _parse_opt
方法,做这样一个操作args.opt = self._parse_opt(args.opt)
之后执行 merge_config(FLAGS.opt)
将 FLAGS.opt
与 global_config
合并
check_config(cfg)
# Article note: the check_config(cfg) call above verifies that the
# configuration is valid.
def check_config(cfg):
    """
    Check the correctness of the configuration file. Log error and exit
    when Config is not compliant.

    Args:
        cfg: Loaded configuration. Must support ``in`` tests and attribute
            assignment (``cfg.log_iter = ...``).

    Returns: cfg, with ``log_iter`` set to 20 when it was absent
    """
    err = "'{}' not specified in config file. Please set it in config file."
    check_list = ['architecture', 'num_classes']

    # No try/except around this loop: sys.exit() raises SystemExit, which an
    # ``except Exception`` would not stop anyway, and swallowing every other
    # error would only hide real problems.
    for var in check_list:
        if var not in cfg:
            logger.error(err.format(var))
            sys.exit(1)

    if 'log_iter' not in cfg:
        cfg.log_iter = 20

    return cfg
该函数检查 cfg 中是否有 'architecture', 'num_classes'
这两项配置,如果无则直接退出
# Module-level config: merge_config() merges into it when no other target is
# given, and load_config() returns it.
global_config = AttrDict()
# Key under which a config file lists the base config files to load first.
BASE_KEY = '_BASE_'
# parse and load _BASE_ recursively
def _load_config_with_base(file_path):
    """Load a YAML config file, resolving its ``_BASE_`` list recursively.

    Each path listed under ``BASE_KEY`` is loaded first, in list order, and
    merged; the keys of ``file_path`` itself are merged last, so they take
    priority over the base files. Relative base paths are resolved against
    the directory of ``file_path``.

    Args:
        file_path (str): Path of the YAML file to load.

    Returns:
        The parsed content of ``file_path`` when it has no ``BASE_KEY``
        entry, otherwise the mapping returned by the final ``merge_config``
        call.
    """
    with open(file_path) as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load config files from trusted sources.
        file_cfg = yaml.load(f, Loader=yaml.Loader)
    if file_cfg is None:
        # An empty YAML document parses to None; treat it as an empty config.
        file_cfg = {}

    # NOTE: cfgs outside have higher priority than cfgs in _BASE_
    if BASE_KEY in file_cfg:
        # NOTE(review): merge_config falls back to global_config when its
        # second argument is falsy, so while all_base_cfg is still empty the
        # first merge appears to target global_config (assumes an empty
        # AttrDict is falsy - TODO confirm).
        all_base_cfg = AttrDict()
        for base_yml in file_cfg[BASE_KEY]:
            if base_yml.startswith("~"):
                base_yml = os.path.expanduser(base_yml)
            if not os.path.isabs(base_yml):
                base_yml = os.path.join(os.path.dirname(file_path), base_yml)

            # The recursive call opens the file itself, so no separate
            # open() is needed here.
            base_cfg = _load_config_with_base(base_yml)
            all_base_cfg = merge_config(base_cfg, all_base_cfg)

        del file_cfg[BASE_KEY]
        return merge_config(file_cfg, all_base_cfg)

    return file_cfg
def load_config(file_path):
    """
    Load a YAML config file and merge it into the global config.

    Args:
        file_path (str): Path of the config file to be loaded; its extension
            must be ``.yml`` or ``.yaml``.

    Returns: global config
    """
    extension = os.path.splitext(file_path)[1]
    assert extension in ('.yml', '.yaml'), "only support yaml files for now"

    # Resolve the file together with everything it pulls in through _BASE_.
    loaded = _load_config_with_base(file_path)

    # Record the config's base name (directory and extension stripped).
    stem, _ = os.path.splitext(os.path.basename(file_path))
    loaded['filename'] = stem

    merge_config(loaded)
    return global_config
def dict_merge(dct, merge_dct):
    """Recursively merge ``merge_dct`` into ``dct`` in place.

    Works like ``dict.update()``, except that when a key holds a dict in
    ``dct`` and a mapping in ``merge_dct`` the two are merged key by key
    (to any depth) instead of the old value being replaced wholesale.

    Args:
        dct: dict onto which the merge is executed
        merge_dct: dct merged into dct

    Returns: dct
    """
    for key, incoming in merge_dct.items():
        both_nested = (key in dct
                       and isinstance(dct[key], dict)
                       and isinstance(incoming, collectionsAbc.Mapping))
        if not both_nested:
            # New key, or at least one side is not a mapping: overwrite.
            dct[key] = incoming
            continue
        dict_merge(dct[key], incoming)
    return dct
def merge_config(config, another_cfg=None):
    """
    Merge config into another_cfg, or into the global config.

    Args:
        config (dict): Config to be merged.
        another_cfg (dict): Merge target. When it is None or empty (any
            falsy value), the global config is used as the target instead.

    Returns: the merge target, updated in place by dict_merge
    """
    global global_config
    target = global_config
    if another_cfg:
        target = another_cfg
    return dict_merge(target, config)