mmdetection Runner

目录

注册钩子类

预定义位点

训练流程中基本hook

Runner调用Hook类中的接口函数控制训练/测试过程

参考


Runner作用,通过注册的hook在预定义的位点执行自定义函数实现自定义训练流程。

注册钩子类

注册训练hooks,其原理是将LrUpdaterHook、OptimizerHook、CheckpointHook、IterTimerHook等钩子类注册到hooks中,实则存储在Runner类的self._hooks列表中。

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
        - LoggerHook(s)
        """
        if optimizer_config is None:
            optimizer_config = {}
        if checkpoint_config is None:
            checkpoint_config = {}
        self.register_lr_hooks(lr_config)
        self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
        self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
        self.register_hook(IterTimerHook())
        if log_config is not None:
            self.register_logger_hooks(log_config)

预定义位点

训练流程中预定义位点包括:

训练开始前和结束后、epoch开始前和结束后、iteration开始前和结束后。如果包含测试过程,则在完成训练后,立即进入测试流程。

mmdetection Runner_第1张图片

训练流程中基本hook

训练流程中基本hook

训练位点 钩子类

分布式钩子类

训练前(before_run) LrUpdaterHook 主要负责Lr更新 Fp16OptimizerHook

1、 StepLrUpdaterHook:继承于LrUpdaterHook,主要用于指定学习策略,进而更新Lr。等间隔调整学习率,让学习率在达到一定iteration或epoch后,学习率更新为原来的  gamma倍。

2、CosineLrUpdaterHook:继承于LrUpdaterHook,余弦退火学习策略。余弦退火让学习率随epoch或iteration的变化类似于cosine,更新策略公式:

  表示学习率最小值,默认0, 当前epoch或iter, 表示学习率基准(也可理解为最大学习率)。

3、ExpLrUpdaterHook:继承于LrUpdaterHook,指数衰减调整学习率。

4、其他

epoch前(before_train_epoch) LrUpdaterHook DistSamplerSeedHook
iteration前(before_train_iter) LrUpdaterHook IterTimerHook
iteration后(after _train_ iter) OptimizerHook OptimizerHook:主要负责梯度清零、反向传播、参数更新及梯度裁剪工作。 IterTimerHook LoggerHook DistOptimizerHook
epoch后(after_train_epoch) CheckpointHook CheckpointHook: 主要负责训练一定epoch后,存储模型。 DistEvalHook
训练后(after_run) CustomHook

 注:用户也可以自定义hook以拓展训练中的行为

Runner调用Hook类中的接口函数控制训练/测试过程

mmdetection Runner_第2张图片

        每个固定的点位都有一个对应的函数接口,如epoch开始前,epoch开始后after_epoch,在每个固定的位点,Runner会调用self._hooks中每个hook的对应函数,如每次iteration开始前,会调用每个hook中的before_iter函数。

mmdetection Runner_第3张图片

        例如,epoch开始前,其对应的函数接口为before_epoch,也就是在继承了Hook类的钩子类重写了before_epoch接口函数,例如LrUpdaterHook钩子类重写了before_run、before_train_epoch、before_train_iter三个函数接口,在训练的不同阶段会调用者三个接口函数去更新学习率。

参考

学术|OpenMMLab开源工具使用教学(一)https://www.bilibili.com/video/BV1sy4y1v7zg/?spm_id_from=333.788.recommend_more_video.-1

学术|OpenMMLab开源工具使用教学(二)_哔哩哔哩_bilibili学术|OpenMMLab开源工具使用教学(二)https://www.bilibili.com/video/BV1Xv411W7Dj/?spm_id_from=333.788.recommend_more_video.-1学术|OpenMMLab开源工具使用教学(三)_哔哩哔哩_bilibili学术|OpenMMLab开源工具使用教学(三)https://www.bilibili.com/video/BV1dT4y1A7WN/?spm_id_from=333.788.recommend_more_video.1

你可能感兴趣的:(mmdetection,pytorch,深度学习,pytorch,人工智能,计算机视觉)