深度理解目标检测(MMdetection)-HOOK机制

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

本文转自|计算机视觉联盟

最近做了一段时间的目标检测,不得不说检测这块还是相对比较复杂的,在熟悉项目的同时也确实学习到了很多有用的东西。MMdetetion是现在最著名、算法包最多并且使用人数最多的训练框架,其中的源码非常值得学习,今天总结下我对其中HOOK(钩子)机制的理解。

MMdetection最近更新很多,我以2.4.0版本的代码进行解读,分享自己的理解,也吸纳观众的点评。HOOK、Runer的定义在MMCV当中,MMdetection和MMCV是版本匹配的,我这里使用的是MMCV 1.1.2的代码。(HOOK相关的定义主要在MMCV中,下面用的代码都是摘自于MMCV)。

1.HOOK机制的作用

MMdetection中的HOOK可以理解为一种触发器,也可以理解为一种训练框架的架构规范,它规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。

首先看一下HOOK的基类定义

# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry

HOOKS = Registry('hook')


class Hook:

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

可以说基类函数中定义了许多我们在模型训练中需要用到的一些功能,如果想定义一些操作我们就可以继承这个类并定制化我们的功能,可以看到HOOK中每一个参数都是有runner作为参数传入的。关于Runner的作用下一篇文章接着说,简而言之,Runner是一个模型训练的工厂,在其中我们可以加载数据、训练、验证以及梯度backward等等全套流程。MMdetection在设计的时候也为runner传入丰富的参数,定义了一个非常好的训练范式。在你的每一个hook函数中,都可以对runner进行你想要的操作。

而HOOK是怎么嵌套进runner中的呢?其实是在Runner中定义了一个hook的list,list中的每一个元素就是一个实例化的HOOK对象。其中提供了两种注册hook的方法,register_hook是传入一个实例化的HOOK对象,并将它插入到一个列表中,register_hook_from_cfg是传入一个配置项,根据配置项来实例化HOOK对象并插入到列表中。当然第二种方法又是MMLab的开源生态中定义的一种基础方法mmcv.build_from_cfg了,无论在MMdetection还是其他MMLab开源的算法框架中,都遵循着MMCV的这套基于配置项实例化对象的方法。毕竟MMCV是提供了一个基础的功能,服务于各个算法框架,这也是为什么MMLab的代码高质量的原因。不仅仅是算法的复现,更是架构、编程范式的一种体现,真·代码如诗

def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.
        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.
        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list
        inserted = False
        # hook是分优先级插入到list中的,在MMdetection中不同的HOOK是有优先级的,为什么呢?稍后在hook的调用中解释哈
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.
        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
              and 'priority' indicating its type and priority.
        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

调用HOOK函数

def call_hook(self, fn_name):
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

可以看到HOOK是调用的时候是遍历List,然后根据HOOK的名字来调用。这也是为什么要区分优先级的原因,优先级越高的放在List的前面,这样就能更快地被调用。当你想用_before_run_epoch_来做A和B两件事情的时候,在runner里面就是调用一次self.before_run_epoch,但是先做A还是先做B,就是通过不同的HOOK的优先级来决定了。比如在evaluation的时候对需要做测试,但是测试前对参数做滑动平均。比如emaHOOK中的72行,也写明了要在测试之前做指数滑动平均。

def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

checkpoint.py的HOOK中,同样也定义了after_train_epoch函数如下:

@master_only
    def after_train_epoch(self, runner):
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return

        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            filename_tmpl = self.args.get('filename_tmpl', 'epoch_{}.pth')
            current_epoch = runner.epoch + 1
            for epoch in range(current_epoch - self.max_keep_ckpts, 0, -1):
                ckpt_path = os.path.join(self.out_dir,
                                         filename_tmpl.format(epoch))
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
                else:
                    break

从测试代码中可以看到不同的HOOK虽然都是重写了after_train_epoch函数,但是调用的顺序还是先调用ema.py中的,然后再调用checkpoint.py中的after_train_epoch

resume_ema_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner()
    runner.model = demo_model
    # 设置了HIGHREST的优先级
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)], 2)

具体的优先级定义有以下7种,作为HOOK的类成员属性。具体定义在链接中。

+------------+------------+
    | Level      | Value      |
    +============+============+
    | HIGHEST    | 0          |
    +------------+------------+
    | VERY_HIGH  | 10         |
    +------------+------------+
    | HIGH       | 30         |
    +------------+------------+
    | NORMAL     | 50         |
    +------------+------------+
    | LOW        | 70         |
    +------------+------------+
    | VERY_LOW   | 90         |
    +------------+------------+
    | LOWEST     | 100        |
    +------------+------------+

2.举一个简单的例子

最近打算好好锻炼身体,健康生活,努力工作,我打算让自己变得更加自律。我给自己定下了几个条例,每天吃早饭之前得晨练30分钟,运动完之后才会感觉充满活力。每天吃午饭之前我得跑上一个实验,吃完饭之后回来刚好可以看下中间结果,吃完午饭之后我感觉结果没问题我需要午休30分钟, 晚上下班前我如果没什么事再锻炼30分钟。秉承着这样的原则我给自己定义一个HOOK来规范我的生活。

  • 定义我的HOOK

import sys
class HOOK:

    def before_breakfast(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))

    def after_breakfast(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))

    def before_lunch(self, runner):
        print('{}:吃午饭之前跑上实验'.format(sys._getframe().f_code.co_name))

    def after_lunch(self, runner):
        print('{}:吃完午饭午休30分钟'.format(sys._getframe().f_code.co_name))

    def before_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))

    def after_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))

    def after_finish_work(self, runner, are_you_busy=False):
        if are_you_busy:
            print('{}:今天事贼多,还是加班吧'.format(sys._getframe().f_code.co_name))
        else:
            print('{}:今天没啥事,去锻炼30分钟'.format(sys._getframe().f_code.co_name))
  • 定义我的Runner

class Runner(object):
    def __init__(self, ):
        pass
        self._hooks = []

    def register_hook(self, hook):
        # 这里不做优先级判断,直接在头部插入HOOK
        self._hooks.insert(0, hook)

    def call_hook(self, hook_name):
        for hook in self._hooks:
            getattr(hook, hook_name)(self)

    def run(self):
        print('开始启动我的一天')
        self.call_hook('before_breakfast')
        self.call_hook('after_breakfast')
        self.call_hook('before_lunch')
        self.call_hook('after_lunch')
        self.call_hook('before_dinner')
        self.call_hook('after_dinner')
        self.call_hook('after_finish_work')
        print('~~睡觉~~')
  • 运行main函数,注册HOOK并且调用Runner.run()开启我的一天

from MyHook import HOOK
from MyRunner import Runner
runner = Runner()
hook = HOOK()
runner.register_hook(hook)
runner.run()
  • 得到的输出结果如下:

开始启动我的一天
before_breakfast:吃早饭之前晨练30分钟
after_breakfast:吃早饭之前晨练30分钟
before_lunch:吃午饭之前跑上实验
after_lunch:吃完午饭午休30分钟
before_dinner: 没想好做什么
after_dinner: 没想好做什么
after_finish_work:今天没啥事,去锻炼30分钟
~~睡觉~~

3.总结

MMdetection中的HOOK设计巧妙,很好地对算法训练、测试进行了抽象和解耦。每一个做上层算法模型的,都值得一看。感谢MMLab贡献这么优质的代码,让我等凡夫俗子醍醐灌顶。

除了HOOK之外,这个代码中还有很多优质的思想。比如Runner是怎么做到包办一切的?注册器这个中枢管理系统是怎么工作的?多卡训练的一些坑是怎么解决的?等等等等,我也在持续地学习和消化。路漫漫其修远兮,吾将上下而求索。

一个小题目:我的代码中每个函数输出的时候都会打印出这个函数名,这个可以用_装饰器_很方便地解决奥。装饰器这个东西在MMLab的系列项目中有大量的应用。其中对fp16的支持让大家赞不绝口。接下来有时间,对Runner、Register、装饰器这些东西好好盘一盘。

end

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

深度理解目标检测(MMdetection)-HOOK机制_第1张图片

深度理解目标检测(MMdetection)-HOOK机制_第2张图片

你可能感兴趣的:(算法,python,人工智能,编程语言,java)