mmdetection - 初识hook的使用

Hook 是什么?在 wiki 百科中定义如下:

钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)

下面我们来看一下,在mmdetection源码中,hook的整个使用流程
1、实例化注册表HOOKS,构建全局的 HOOKS 注册器类

##只是设置了self._name = 'hook',self.build_func = build_from_cfg
HOOKS = Registry('hook')

2、python装饰器函数,将类名和类添加到HOOKS._module_dict中,类似下面这种语法。

#通过装饰器方式 把key-value添加到字典中self._module_dict[name] = module_class
@HOOKS.register_module()
class CheckpointHook(Hook):
#这里需要注意,所有的这些子类都继承自Hook,Hook定义了不同阶段名称和对应的函数

3、通过下面引入操作,执行装饰器函数

from .checkpoint import CheckpointHook

4、5、6、7用配置参数去实例化Hook子类
4、在train_detector函数中,开始训练之前,出现了runner.register_training_hooks这个函数的调用:

 runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

5、runner.register_training_hooks这个函数体内分别注册多个hook,具体实现如下:

def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)

6、我们再以第一个函数self.register_lr_hook为例,看一下具体注册过程,主要步骤就是先利用配置文件和HOOKS去实例化这个类,然后

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
  			#首字母转为大写
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            #我的例程中hook_type= 'StepLrUpdaterHook'
            lr_config['type'] = hook_type
            #HOOKS = Registry('hook')是之前注册好的,在 mmcv.build_from_cfg函数中,通过类名获得类,
            #然后再用lr_config里面的参数对类进行实例化。class StepLrUpdaterHook(LrUpdaterHook):
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        #把实例好的类插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
        self.register_hook(hook, priority='VERY_HIGH')

7、下面一段伪代码build_from_cfg,展示了类实例化的主要过程。

def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)

8、把实例好的类插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中

 def register_hook(self, hook, priority='NORMAL'):
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

9、call_hook,钩子函数的调用

self.call_hook('before_train_epoch')

#遍历 self._hooks列表中的类(这些类是按照优先级排序过的),按照fn_name获取hook类中对应的函数,再执行,class Hook是这些类的父类,这些函数在父类中声明了,但并没有实现。包括before_run,after_run,before_epoch,after_epoch,before_iter,after_iter,before_train_epoch,before_val_epoch,after_train_epoch,after_val_epoch等等…

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

总结:在 python 中由于函数是一等公民,实现 hook 机制其实只需要传入一个函数即可。在mmdetection源码中,我们通过call_hook函数,实际上传入多个提前注册好的函数,实现了模型参数保存、学习率调度、梯度反向传播加上参数更新等功能。
1、实例化HOOKS
2、装饰器函数,将类名和类添加到HOOKS._module_dict
3、配置参数去实例化Hook子类
4、把实例好的Hook子类按照优先级顺序插入到当前类也就是class BaseRunner(metaclass=ABCMeta)的self._hook中
5、调用钩子函数的,根据函数名去self._hook查找调用

你可能感兴趣的:(mmdtection,python,目标检测,深度学习)