Hook 是什么?在 wiki 百科中定义如下:
钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)
下面我们来看一下,在mmdetection源码中,hook的整个使用流程
1、实例化注册表HOOKS,构建全局的 HOOKS 注册器类
##只是设置了self._name = 'hook',self.build_func = build_from_cfg
HOOKS = Registry('hook')
2、python装饰器函数,将类名和类添加到HOOKS._module_dict中,类似下面这种语法。
#通过装饰器方式 把key-value添加到字典中self._module_dict[name] = module_class
@HOOKS.register_module()
class CheckpointHook(Hook):
#这里需要注意,所有的这些子类都继承自Hook,Hook定义了不同阶段名称和对应的函数
3、通过下面引入操作,执行装饰器函数
from .checkpoint import CheckpointHook
4、5、6、7用配置参数去实例化Hook子类
4、在train_detector函数中,开始训练之前,出现了runner.register_training_hooks这个函数的调用:
runner.register_training_hooks(
cfg.lr_config,
optimizer_config,
cfg.checkpoint_config,
cfg.log_config,
cfg.get('momentum_config', None),
custom_hooks_config=cfg.get('custom_hooks', None))
5、runner.register_training_hooks这个函数体内分别注册多个hook,具体实现如下:
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
6、我们再以第一个函数self.register_lr_hook为例,看一下具体注册过程,主要步骤就是先利用配置文件和HOOKS去实例化这个类,然后
def register_lr_hook(self, lr_config):
if lr_config is None:
return
elif isinstance(lr_config, dict):
assert 'policy' in lr_config
policy_type = lr_config.pop('policy')
#首字母转为大写
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'LrUpdaterHook'
#我的例程中hook_type= 'StepLrUpdaterHook'
lr_config['type'] = hook_type
#HOOKS = Registry('hook')是之前注册好的,在 mmcv.build_from_cfg函数中,通过类名获得类,
#然后再用lr_config里面的参数对类进行实例化。class StepLrUpdaterHook(LrUpdaterHook):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
#把实例好的类插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
self.register_hook(hook, priority='VERY_HIGH')
7、下面一段伪代码build_from_cfg,展示了类实例化的主要过程。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
obj_type = args.pop('type')
obj_cls = registry.get(obj_type)
return obj_cls(**args)
8、把实例好的类插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
def register_hook(self, hook, priority='NORMAL'):
assert isinstance(hook, Hook)
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
9、call_hook,钩子函数的调用
self.call_hook('before_train_epoch')
#遍历 self._hooks列表中的类(这些类是按照优先级排序过的),按照fn_name获取hook类中对应的函数,再执行,class Hook是这些类的父类,这些函数在父类中声明了,但并没有实现。包括before_run,after_run,before_epoch,after_epoch,before_iter,after_iter,before_train_epoch,before_val_epoch,after_train_epoch,after_val_epoch等等…
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
总结:在 python 中由于函数是一等公民,实现 hook 机制其实只需要传入一个函数即可。在mmdetection源码中,我们通过call_hook函数,实际上传入多个提前注册好的函数,实现了模型参数保存、学习率调度、梯度反向传播加上参数更新等功能。
1、实例化HOOKS
2、装饰器函数,将类名和类添加到HOOKS._module_dict
3、配置参数去实例化Hook子类
4、把实例好的Hook子类按照优先级顺序插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
5、调用钩子函数的,根据函数名去self._hook查找调用