A Detailed Look at the Hook Mechanism in OpenMMLab

Hooks

The hook mechanism is used extensively across the OpenMMLab frameworks. Combined with the Runner class, it manages the entire training lifecycle, e.g. adjusting the learning rate, saving checkpoints, and stepping the optimizer.
Hooks are registered into the Runner, and this is how its rich extension features are plugged in; a minimal registration sketch is given below. We then take the train workflow as an example and analyze where and how the hooks are called.
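As a quick illustration, here is a minimal sketch of registering a custom hook with an mmcv Runner. The CheckInvalidLossHook class is written here purely for illustration; register_hook, every_n_iters and the custom_hooks config convention are regular mmcv facilities.

import torch
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Toy hook: periodically check that the training loss is finite.

    Only the lifecycle method we care about is overridden; everything else
    is inherited from the base Hook class as a no-op.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        if self.every_n_iters(runner, self.interval):
            assert torch.isfinite(runner.outputs['loss']), \
                'loss became infinite or NaN!'


# Either register it on an existing runner instance ...
# runner.register_hook(CheckInvalidLossHook(interval=50), priority='LOW')
# ... or let the runner build it from the registry via the config file:
# custom_hooks = [dict(type='CheckInvalidLossHook', interval=50)]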
1. Where hooks are called

We take EpochBasedRunner (a subclass of BaseRunner) as the example.

mmcv/runner/epoch_based_runner.py

class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.
    This runner trains models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        ###### 1. hook point: before each training epoch
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            ###### 2. hook point: before each training iteration
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            ###### 3. hook point: after each training iteration (e.g. OptimizerHook)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

        ###### 4. hook point: after each training epoch
        self.call_hook('after_train_epoch')
        self._epoch += 1

 mmcv/runner/base_runner.py

class BaseRunner(metaclass=ABCMeta):
    def call_hook(self, fn_name: str) -> None:
        """Call all hooks.
        Args:
            fn_name (str): The function name in each hook to be called, such as
             "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

Looking at the code, we can see that there are four points in the training lifecycle at which hooks are invoked: before_train_epoch, before_train_iter, after_train_iter, and after_train_epoch. Why are they named this way?

2. How hooks are called

During training, self.call_hook executes the hooks' concrete operations; let us take OptimizerHook as an example. Looking at call_hook, we see that it loops over the registered hooks and calls getattr. getattr(hook, fn_name) looks up the attribute named fn_name on the hook object; because the result is immediately called with (self), the net effect is to invoke the hook's method of that name, i.e. getattr(hook, fn_name)(self).
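As a quick reminder of what getattr does (the Greeter class below is purely illustrative), it can either fetch a plain attribute or fetch a method, which is then called by appending ():

class Greeter:
    name = 'runner'

    def before_train_epoch(self, runner):
        print(f'hello from {self.name}, got {runner}')


g = Greeter()
print(getattr(g, 'name'))               # plain attribute lookup -> 'runner'
fn = getattr(g, 'before_train_epoch')   # method lookup, not yet called
fn(0)                                   # the trailing () performs the call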

Now we can explain the naming. Each hook class implements its behaviour in methods that follow exactly this naming scheme; for example, optimizer.py defines an after_train_iter method, which corresponds one-to-one with the string passed to call_hook. In other words, this naming convention, combined with getattr, is what dispatches execution to the right hook methods.

In short, the register mechanism first appends all hook objects to self._hooks; call_hook then iterates over self._hooks and uses getattr to invoke, on each hook, the method whose name matches the current stage, so the method names decide which hooks react at which stage.
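The whole dispatch therefore fits in a few lines. Below is a stripped-down, hypothetical MiniRunner/PrintHook pair that mimics the register-then-call flow (the real BaseRunner.register_hook additionally inserts hooks according to their priority):

class PrintHook:
    """Toy hook implementing exactly the two stages the loop below calls."""

    def before_train_epoch(self, runner):
        print('epoch starts')

    def after_train_iter(self, runner):
        print(f'iter {runner.iter} done')


class MiniRunner:
    def __init__(self):
        self._hooks = []
        self.iter = 0

    def register_hook(self, hook):
        # the real implementation also takes a `priority` argument
        self._hooks.append(hook)

    def call_hook(self, fn_name):
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def train(self, num_iters=2):
        self.call_hook('before_train_epoch')
        for _ in range(num_iters):
            self.call_hook('after_train_iter')
            self.iter += 1


runner = MiniRunner()
runner.register_hook(PrintHook())
runner.train()   # epoch starts / iter 0 done / iter 1 done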

Readers familiar with getattr may object: since each hook class only implements the member functions it cares about, the loop is bound to ask a hook for a stage it never defined. For example, when self.call_hook('before_train_epoch') runs, OptimizerHook defines no before_train_epoch of its own, so shouldn't getattr raise an AttributeError?
This is a good question. The reason it works is that every xxxHook inherits from a common parent class, Hook, which defines all of the possible lifecycle methods; a subclass only overrides the ones it actually needs. The method therefore always exists, it just performs no operation by default, as the abridged sketch below shows.
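An abridged sketch of the parent class in mmcv/runner/hooks/hook.py (only a few of its methods are shown here): every stage method exists and does nothing by default, and the train-specific names simply delegate to the generic ones, so getattr always finds something callable.

class Hook:

    def before_epoch(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)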

mmcv/runner/hooks/optimizer.py

import logging
from typing import Optional

from torch import Tensor
from torch.nn.utils import clip_grad

from .hook import HOOKS, Hook


@HOOKS.register_module()
class OptimizerHook(Hook):
    """A hook contains custom operations for the optimizer.
    Args:
        grad_clip (dict, optional): A config dict to control the clip_grad.
            Default: None.
        detect_anomalous_params (bool): This option is only used for
            debugging which will slow down the training speed.
            Detect anomalous parameters that are not included in
            the computational graph with `loss` as the root.
            There are two cases
                - Parameters were not used during
                  forward pass.
                - Parameters were not used to produce
                  loss.
            Default: False.
    """

    def __init__(self,
                 grad_clip: Optional[dict] = None,
                 detect_anomalous_params: bool = False):
        self.grad_clip = grad_clip
        self.detect_anomalous_params = detect_anomalous_params

    def clip_grads(self, params):
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        if self.detect_anomalous_params:
            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
        runner.outputs['loss'].backward()

        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()

    def detect_anomalous_parameters(self, loss: Tensor, runner) -> None:
        logger = runner.logger
        parameters_in_graph = set()
        visited = set()

        def traverse(grad_fn):
            if grad_fn is None:
                return
            if grad_fn not in visited:
                visited.add(grad_fn)
                if hasattr(grad_fn, 'variable'):
                    parameters_in_graph.add(grad_fn.variable)
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        grad_fn = parent[0]
                        traverse(grad_fn)

        traverse(loss.grad_fn)
        for n, p in runner.model.named_parameters():
            if p not in parameters_in_graph and p.requires_grad:
                logger.log(
                    level=logging.ERROR,
                    msg=f'{n} with shape {p.size()} is not '
                    f'in the computational graph \n')
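For completeness: OptimizerHook is normally not instantiated by hand. In mmcv-based configs it is built from the optimizer_config field, which the runner turns into a registered hook; the numbers below are just example values.

# in a config file
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
# gradient clipping is enabled simply by passing grad_clip to OptimizerHook
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))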


 
