MMDetection —— Hook机制

MMOCR的hook定义在

/home/xhhao/anaconda3/envs/open-mmocr/lib/python3.7/site-packages/mmcv/runner/hooks/hook.py

这只是hook基类

MMDetection —— Hook机制_第1张图片

具体用的是哪个hook是在mmocr/apis/train.py

这是train的时候用的DistSamplerSeedHook

这是val的时候用的DistEvalHook

Hook机制规定了在算法训练过程中的种种操作,并且我们可以通过继承HOOK类,然后注册HOOK自定义我们想要的操作。

HOOK基类的定义

from mmcv.utils import Registry

HOOKS = Registry('hook')


class Hook:

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

在baserunner类中有register_hook函数,还有很多地方有

hook函数是有多种类型的

hook优先级

MMDetection —— Hook机制_第2张图片

MMDetection —— Hook机制_第3张图片

import sys
class HOOK:

    def before_breakfast(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))

    def after_breakfast(self, runner):
        print('{}:吃早饭之前晨练30分钟'.format(sys._getframe().f_code.co_name))

    def before_lunch(self, runner):
        print('{}:吃午饭之前跑上实验'.format(sys._getframe().f_code.co_name))

    def after_lunch(self, runner):
        print('{}:吃完午饭午休30分钟'.format(sys._getframe().f_code.co_name))

    def before_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))

    def after_dinner(self, runner):
        print('{}: 没想好做什么'.format(sys._getframe().f_code.co_name))

    def after_finish_work(self, runner, are_you_busy=False):
        if are_you_busy:
            print('{}:今天事贼多,还是加班吧'.format(sys._getframe().f_code.co_name))
        else:
            print('{}:今天没啥事,去锻炼30分钟'.format(sys._getframe().f_code.co_name))

class Runner(object):
    def __init__(self, ):
        pass
        self._hooks = []

    def register_hook(self, hook):
        # 这里不做优先级判断,直接在头部插入HOOK
        self._hooks.insert(0, hook) #将hook这个类插入到self._hook list中的第0个位置
        

    def call_hook(self, hook_name):
        for hook in self._hooks:
            getattr(hook, hook_name)(self)

    def run(self):
        print('开始启动我的一天')
        self.call_hook('before_breakfast')
        self.call_hook('after_breakfast')
        self.call_hook('before_lunch')
        self.call_hook('after_lunch')
        self.call_hook('before_dinner')
        self.call_hook('after_dinner')
        self.call_hook('after_finish_work')
        print('~~睡觉~~')



runner = Runner()
hook = HOOK()
runner.register_hook(hook)
runner.run()

runner中用到哪个hook要在main函数中给它注册

MMDetection —— Hook机制_第4张图片

你可能感兴趣的:(MMDetection)