第十三章 MMDetection3D解析系列四_钩子(hook)(车道线感知)

一 前言

近期参与到了手写AI的车道线检测的学习中去,以此系列笔记记录学习与思考的全过程。车道线检测系列会持续更新,力求完整精炼,引人启示。所需前期知识,可以结合手写AI进行系统的学习。

二 介绍

日志系统
钩子编程是一种编程模式,是指在程序的一个或者多个位置设置位点(挂载点),**当程序运行至某个位点时,会自动调用运行时注册到位点的所有方法。**钩子编程可以提高程序的灵活性和拓展性,用户将自定义的方法注册到位点便可被调用而无需修改程序中的代码。

三 hook优先级

HIGHEST (0)
VERY_HIGH (10)
HIGH (30)
ABOVE_NORMAL (40)
NORMAL (50)
BELOW_NORMAL (60)
LOW (70)
VERY_LOW (90)
LOWEST (100)

四 各种hook

4.1 默认钩子

名称 用途 优先级
RuntimeInfoHook 往 message hub 更新运行时信息 VERY_HIGH (10)
IterTimerHook 统计迭代耗时 NORMAL (50)
DistSamplerSeedHook 确保分布式 Sampler 的 shuffle 生效 NORMAL (50)
LoggerHook 打印日志 BELOW_NORMAL (60)
ParamSchedulerHook 调用 ParamScheduler 的 step 方法 LOW (70)
CheckpointHook 按指定间隔保存权重 VERY_LOW (90)

4.2 自定义钩子

名称 用途 优先级
EMAHook 模型参数指数滑动平均 NORMAL (50)
EmptyCacheHook PyTorch CUDA 缓存清理 NORMAL (50)
SyncBuffersHook 同步模型的 buffer NORMAL (50)

4.3 用法简介

两种钩子在执行器中的设置不同,默认钩子的配置传给执行器的 default_hooks 参数,自定义钩子的配置传给 custom_hooks 参数,如下所示:

from mmengine.runner import Runner

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    logger=dict(type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1),
)

custom_hooks = [dict(type='EmptyCacheHook')]

runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()

五 各种hook具体用法

5.1 CheckpointHook

CheckpointHook 按照给定间隔保存模型的权重,如果是分布式多卡训练,则只有主(master)进程会保存权重。CheckpointHook 的主要功能如下:

按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重
保存最新的多个权重
保存最优权重
指定保存权重的路径
制作发布用的权重
设置开始保存权重的 epoch 数或者 iteration 数

下面介绍上面提到的 6 个功能。

5.1.1 按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重

假设我们一共训练 20 个 epoch 并希望每隔 5 个 epoch 保存一次权重,下面的配置即可帮我们实现该需求。

# by_epoch 的默认值为 True
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))

如果想以迭代次数作为保存间隔,则可以将 by_epoch 设为 False,interval=5 则表示每迭代 5 次保存一次权重。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))

5.1.2 保存最新的多个权重

如果只想保存一定数量的权重,可以通过设置 max_keep_ckpts 参数实现最多保存 max_keep_ckpts 个权重,当保存的权重数超过 max_keep_ckpts 时,前面的权重会被删除。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))

上述例子表示,假如一共训练 20 个 epoch,那么会在第 5, 10, 15, 20 个 epoch 保存模型,但是在第 15 个 epoch 的时候会删除第 5 个 epoch 保存的权重,在第 20 个 epoch 的时候会删除第 10 个 epoch 的权重,最终只有第 15 和第 20 个 epoch 的权重才会被保存。

5.1.3 保存最优权重

如果想要保存训练过程验证集的最优权重,可以设置 save_best 参数,如果设置为 ‘auto’,则会根据验证集的第一个评价指标(验证集返回的评价指标是一个有序字典)判断当前权重是否最优。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))

也可以直接指定 save_best 的值为评价指标,例如在分类任务中,可以指定为 save_best=‘top-1’,则会根据 ‘top-1’ 的值判断当前权重是否最优。

除了 save_best 参数,和保存最优权重相关的参数还有 rule,greater_keys 和 less_keys,这三者用来判断 save_best 的值是越大越好还是越小越好。例如指定了 save_best=‘top-1’,可以指定 rule=‘greater’,则表示该值越大表示权重越好。

5.1.4 指定保存权重的路径

权重默认保存在工作目录(work_dir),但可以通过设置 out_dir 改变保存路径。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))

5.1.5 制作发布用的权重

如果你想在训练结束后自动生成可发布的权重(删除不需要的权重,例如优化器状态),你可以设置 published_keys 参数,选择需要保留的信息。注意:需要相应设置 save_best 或者 save_last 参数,这样才会生成可发布的权重,其中设置 save_best 会生成最优权重的可发布权重,设置 save_last 会生成最后一个权重的可发布权重,这两个参数也可同时设置。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))

5.1.6 设置开始保存权重的 epoch 数或者 iteration 数

如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 save_begin 参数,默认为 0,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 save_begin 设置为 5,则将保存第 5、6、7、8、9 和 10 个 epoch 的权重。如果 interval=2,则仅保存第 5、7 和 9 个 epoch 的权重。

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))

5.2 LoggerHook

LoggerHook负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。

如果我们希望每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数,配置如下:

default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
  1. 记录日志
    执行器(Runner)在运行过程中会产生很多日志,例如损失、迭代时间、学习率等。MMEngine 实现了一套灵活的日志系统让我们能够在配置执行器时,选择不同类型日志的统计方式;在代码的任意位置,新增需要被统计的日志。
    log_processor
log_processor = dict(window_size=10, by_epoch=True, custom_cfg=None, num_digits=4)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01))
)
runner.train()

终端显示内容

08/21 02:58:41 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0019  data_time: 0.0004  loss1: 0.8381  loss2: 0.9007  loss: 1.7388
08/21 02:58:41 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0029  data_time: 0.0010  loss1: 0.1978  loss2: 0.4312  loss: 0.6290

日志前缀:
Epoch 模式(by_epoch=True):Epoch(train) [{当前epoch次数}][{当前迭代次数}/{Dataloader 总长度}]
Iter 模式(by_epoch=False): Iter(train) [{当前迭代次数}/{总迭代次数}]

学习率(lr):统计最近一次迭代,参数更新的学习率

时间
迭代时间(time):最近 window_size(日志处理器参数) 次迭代,处理一个 batch 数据(包括数据加载和模型前向推理)的平均时间
数据时间(data_time):最近 window_size 次迭代,加载一个 batch 数据的平均时间
剩余时间(eta):根据总迭代次数和历次迭代时间计算出来的总剩余时间,剩余时间随着迭代次数增加逐渐趋于稳定

损失:模型前向推理得到的各种字段的损失,默认统计最近 window_size 次迭代的平均损失。

默认情况下,window_size=10,日志处理器会统计最近 10 次迭代,损失、迭代时间、数据时间的均值。

默认情况下,所有日志的有效位数(num_digits 参数)为 4。

默认情况下,输出所有自定义日志最近一次迭代的值。

基于上述规则,代码示例中的日志处理器会输出 loss1 和 loss2 每 10 次迭代的均值。如果我们想统计 loss1 从第一次迭代开始至今的全局均值,可以这样配置:

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(  # 配置日志处理器
        custom_cfg=[
            dict(data_src='loss1',  # 原日志名:loss1
                 method_name='mean',  # 统计方法:均值统计
                 window_size='global')])  # 统计窗口:全局
)
runner.train()

08/21 02:58:49 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0026  data_time: 0.0007  loss1: 0.7381  loss2: 0.8446  loss: 1.5827
08/21 02:58:49 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0030  data_time: 0.0012  loss1: 0.4521  loss2: 0.3939  loss: 0.5600

log_processor 默认输出 by_epoch=True 格式的日志。日志格式需要和 train_cfg 中的 by_epoch 参数保持一致,例如我们想按迭代次数输出日志,就需要另 log_processor 和 train_cfg 的 by_epoch=False。

其中 data_src 为原日志名,mean 为统计方法,global 为统计方法的参数。这样的话,日志中统计的 loss1 就是全局均值。我们可以在日志处理器中配置以下统计方法:

统计方法 参数 功能
mean window_size 统计窗口内日志的均值
min window_size 统计窗口内日志的最小值
max window_size 统计窗口内日志的最大值
current / 返回最近一次更新的日志

其中window_size的值可以是:

数字:表示统计窗口的大小
global:统计全局的最大、最小和均值
epoch:统计一个 epoch 内的最大、最小和均值

如果我们既想统计窗口为 10 的 loss1 的局部均值,又想统计 loss1 的全局均值,则需要额外指定 log_name:

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(
        custom_cfg=[
            # log_name 表示 loss1 重新统计后的日志名
            dict(data_src='loss1', log_name='loss1_global', method_name='mean', window_size='global')])
)
runner.train()

08/21 18:39:32 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0016  data_time: 0.0004  loss1: 0.1512  loss2: 0.3751  loss: 0.5264  loss1_global: 0.1512
08/21 18:39:32 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0051  data_time: 0.0036  loss1: 0.0113  loss2: 0.0856  loss: 0.0970  loss1_global: 0.0813

类似地,我们也可以统计 loss1 的局部最大值和全局最大值:

runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(custom_cfg=[
        # 统计 loss1 的局部最大值,统计窗口为 10,并在日志中重命名为 loss1_local_max
        dict(data_src='loss1',
             log_name='loss1_local_max',
             window_size=10,
             method_name='max'),
        # 统计 loss1 的全局最大值,并在日志中重命名为 loss1_local_max
        dict(
            data_src='loss1',
            log_name='loss1_global_max',
            method_name='max',
            window_size='global')
    ]))
runner.train()

08/21 03:17:26 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0021  data_time: 0.0006  loss1: 1.8495  loss2: 1.3427  loss: 3.1922  loss1_local_max: 2.8872  loss1_global_max: 2.8872
08/21 03:17:26 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0024  data_time: 0.0010  loss1: 0.5464  loss2: 0.7251  loss: 1.2715  loss1_local_max: 2.8872  loss1_global_max: 2.8872
  1. 自定义统计内容
    除了 MMEngine 默认的日志统计类型,如损失、迭代时间、学习率,用户也可以自行添加日志的统计内容。例如我们想统计损失的中间结果,可以这样做:
from mmengine.logging import MessageHub


class ToyModel(BaseModel):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss_tmp = (feat - label).abs()
        loss = loss_tmp.pow(2)

        message_hub = MessageHub.get_current_instance()
        # 在日志中额外统计 `loss_tmp`
        message_hub.update_scalar('train/loss_tmp', loss_tmp.sum())
        return dict(loss=loss)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    log_processor=dict(
        custom_cfg=[
        # 统计 loss_tmp 的局部均值
            dict(
                data_src='loss_tmp',
                window_size=10,
                method_name='mean')
        ]
    )
)
runner.train()

08/21 03:40:31 - mmengine - INFO - Epoch(train) [1][10/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0026  data_time: 0.0008  loss_tmp: 0.0097  loss: 0.0000
08/21 03:40:31 - mmengine - INFO - Epoch(train) [1][20/25]  lr: 1.0000e-02  eta: 0:00:00  time: 0.0028  data_time: 0.0013  loss_tmp: 0.0065  loss: 0.0000

通过调用消息枢纽的接口实现自定义日志的统计,具体步骤如下:

调用 get_current_instance 接口获取执行器的消息枢纽。
调用 update_scalar 接口更新日志内容,其中第一个参数为日志的名称,日志名称以 train/,val/,test/ 前缀打头,用于区分训练状态,然后才是实际的日志名,如上例中的 train/loss_tmp,这样统计的日志中就会出现 loss_tmp。
配置日志处理器,以均值的方式统计 loss_tmp。如果不配置,日志里显示 loss_tmp 最近一次更新的值。

  1. 输出调试日志
    初始化执行器(Runner)时,将 log_level 设置成 debug。这样终端上就会额外输出日志等级为 debug 的日志
runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    log_level='DEBUG',
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
runner.train()

08/21 18:16:22 - mmengine - DEBUG - Get class `LocalVisBackend` from "vis_backend" registry in "mmengine"
08/21 18:16:22 - mmengine - DEBUG - An `LocalVisBackend` instance is built from registry, its implementation can be found in mmengine.visualization.vis_backend
08/21 18:16:22 - mmengine - DEBUG - Get class `RuntimeInfoHook` from "hook" registry in "mmengine"
08/21 18:16:22 - mmengine - DEBUG - An `RuntimeInfoHook` instance is built from registry, its implementation can be found in mmengine.hooks.runtime_info_hook
08/21 18:16:22 - mmengine - DEBUG - Get class `IterTimerHook` from "hook" registry in "mmengine"
...

5.3 ParamSchedulerHook

ParamSchedulerHook 遍历执行器的所有优化器参数调整策略(Parameter Scheduler)并逐个调用 step 方法更新优化器的参数。ParamSchedulerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

5.4 IterTimerHook

IterTimerHook 用于记录加载数据的时间以及迭代一次耗费的时间。IterTimerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

5.5 DistSamplerSeedHook

DistSamplerSeedHook 在分布式训练时调用 Sampler 的 step 方法以确保 shuffle 参数生效。DistSamplerSeedHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

5.6 RuntimeInfoHook

RuntimeInfoHook 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,以便其他无法访问执行器的模块能够获取到这些信息。RuntimeInfoHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。

5.7 EMAHook

EMAHook 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。

custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

EMAHook 默认使用 ExponentialMovingAverage,可选值还有 StochasticWeightAverage 和 MomentumAnnealingEMA。可以通过设置 ema_type 使用其他的平均策略。

custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]

5.8 EmptyCacheHook

EmptyCacheHook 调用 torch.cuda.empty_cache() 释放未被使用的显存。可以通过设置 before_epoch, after_iter 以及 after_epoch 参数控制释显存的时机,第一个参数表示在每个 epoch 开始之前,第二参数表示在每次迭代之后,第三个参数表示在每个 epoch 之后。

# 每一个 epoch 结束都会执行释放操作
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

5.9 SyncBuffersHoo

SyncBuffersHook 在分布式训练每一轮(epoch)结束时同步模型的 buffer,例如 BN 层的 running_mean 以及 running_var。

custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()

六 自定义钩子

如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。

例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 after_train_iter 位点。

import torch

from mmengine.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Defaults to 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
        """All subclasses should override this method, if they need any
        operations after each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']),\
                runner.logger.info('loss become infinite or NaN!')

我们只需将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子

from mmengine.runner import Runner

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...)  # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train()  # 执行器开始训练

便会在每次模型前向计算后检查损失值。

注意,自定义钩子的优先级默认为 NORMAL (50),如果想改变钩子的优先级,则可以在配置中设置 priority 字段。

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]

也可以在定义类时给定优先级

@HOOKS.register_module()
class CheckInvalidLossHook(Hook):

    priority = 'ABOVE_NORMAL'

七 hook使用

pre_hooks = [(print, 'hello')]
post_hooks = [(print, 'goodbye')]

def main():
    for func, arg in pre_hooks:
        func(arg)
    print('do something here')
    for func, arg in post_hooks:
        func(arg)

main()

hello
do something here
goodbye

在 PyTorch 中,钩子的应用也随处可见,例如神经网络模块(nn.Module)中的钩子可以获得模块的前向输入输出以及反向的输入输出。以 register_forward_hook 方法为例,该方法往模块注册一个前向钩子,钩子可以获得模块的前向输入和输出。
下面是 register_forward_hook 用法的简单示例:

import torch
import torch.nn as nn

def forward_hook_fn(
    module,  # 被注册钩子的对象
    input,  # module 前向计算的输入
    output,  # module 前向计算的输出
):
    print(f'"forward_hook_fn" is invoked by {module.name}')
    print('weight:', module.weight.data)
    print('bias:', module.bias.data)
    print('input:', input)
    print('output:', output)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        y = self.fc(x)
        return y

model = Model()
# 将 forward_hook_fn 注册到 model 每个子模块
for module in model.children():
    module.register_forward_hook(forward_hook_fn)

x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)

"forward_hook_fn" is invoked by Linear(in_features=3, out_features=1, bias=True)
weight: tensor([[-0.4077,  0.0119, -0.3606]])
bias: tensor([-0.2943])
input: (tensor([[0., 1., 2.]]),)
output: tensor([[-1.0036]], grad_fn=<AddmmBackward>)

可以看到注册到 Linear 模块的 forward_hook_fn 钩子被调用,在该钩子中打印了 Linear 模块的权重、偏置、模块的输入以及输出。

八 MMEngine 中钩子的设计

在介绍 MMEngine 中钩子的设计之前,先简单介绍使用 PyTorch 实现模型训练的基本步骤

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    pass

class Net(nn.Module):
    pass

def main():
    transform = transforms.ToTensor()
    train_dataset = CustomDataset(transform=transform, ...)
    val_dataset = CustomDataset(transform=transform, ...)
    test_dataset = CustomDataset(transform=transform, ...)
    train_dataloader = DataLoader(train_dataset, ...)
    val_dataloader = DataLoader(val_dataset, ...)
    test_dataloader = DataLoader(test_dataset, ...)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for i in range(max_epochs):
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for inputs, labels in val_dataloader:
                outputs = net(inputs)
                loss = criterion(outputs, labels)

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = net(inputs)
            accuracy = ...

上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 main 函数。为了提高 main 函数的灵活性和拓展性,我们可以在 main 方法中插入位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。

def main():
    ...
    call_hooks('before_run', hooks)  # 任务开始前执行的逻辑
    call_hooks('after_load_checkpoint', hooks)  # 加载权重后执行的逻辑
    call_hooks('before_train', hooks)  # 训练开始前执行的逻辑
    for i in range(max_epochs):
        call_hooks('before_train_epoch', hooks)  # 遍历训练数据集前执行的逻辑
        for inputs, labels in train_dataloader:
            call_hooks('before_train_iter', hooks)  # 模型前向计算前执行的逻辑
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            call_hooks('after_train_iter', hooks)  # 模型前向计算后执行的逻辑
            loss.backward()
            optimizer.step()
        call_hooks('after_train_epoch', hooks)  # 遍历完训练数据集后执行的逻辑

        call_hooks('before_val_epoch', hooks)  # 遍历验证数据集前执行的逻辑
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                call_hooks('before_val_iter', hooks)  # 模型前向计算前执行
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                call_hooks('after_val_iter', hooks)  # 模型前向计算后执行
        call_hooks('after_val_epoch', hooks)  # 遍历完验证数据集前执行

        call_hooks('before_save_checkpoint', hooks)  # 保存权重前执行的逻辑
    call_hooks('after_train', hooks)  # 训练结束后执行的逻辑

    call_hooks('before_test_epoch', hooks)  # 遍历测试数据集前执行的逻辑
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            call_hooks('before_test_iter', hooks)  # 模型前向计算后执行的逻辑
            outputs = net(inputs)
            accuracy = ...
            call_hooks('after_test_iter', hooks)  # 遍历完成测试数据集后执行的逻辑
    call_hooks('after_test_epoch', hooks)  # 遍历完测试数据集后执行

    call_hooks('after_run', hooks)  # 任务结束后执行的逻辑

为了方便管理,MMEngine 将位点定义为方法并集成到钩子基类(Hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。

before_run
after_run
before_train
after_train
before_train_epoch
after_train_epoch
before_train_iter
after_train_iter
before_val
after_val
before_val_epoch
after_val_epoch
before_val_iter
after_val_iter
before_test
after_test
before_test_epoch
after_test_epoch
before_test_iter
after_test_iter
before_save_checkpoint
after_load_checkpoint

九 具体点位的钩子

9.1 钩子示例

pre_hooks = [(print, 'hello')]
post_hooks = [(print, 'goodbye')]

def main():
    for func, arg in pre_hooks:
        func(arg)
    print('do something here')
    for func, arg in post_hooks:
        func(arg)
main()

hello
do something here
goodbye

在 PyTorch 中,钩子的应用也随处可见,例如神经网络模块(nn.Module)中的钩子可以获得模块的前向输入输出以及反向的输入输出。以 register_forward_hook 方法为例,该方法往模块注册一个前向钩子,钩子可以获得模块的前向输入和输出。

import torch
import torch.nn as nn

def forward_hook_fn(
    module,  # 被注册钩子的对象
    input,  # module 前向计算的输入
    output,  # module 前向计算的输出
):
    print(f'"forward_hook_fn" is invoked by {module.name}')
    print('weight:', module.weight.data)
    print('bias:', module.bias.data)
    print('input:', input)
    print('output:', output)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        y = self.fc(x)
        return y

model = Model()
# 将 forward_hook_fn 注册到 model 每个子模块
for module in model.children():
    module.register_forward_hook(forward_hook_fn)

x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)
"forward_hook_fn" is invoked by Linear(in_features=3, out_features=1, bias=True)
weight: tensor([[-0.4077,  0.0119, -0.3606]])
bias: tensor([-0.2943])
input: (tensor([[0., 1., 2.]]),)
output: tensor([[-1.0036]], grad_fn=<AddmmBackward>)

可以看到注册到 Linear 模块的 forward_hook_fn 钩子被调用,在该钩子中打印了 Linear 模块的权重、偏置、模块的输入以及输出。更多关于 PyTorch 钩子的用法可以阅读 nn.Module。

9.2 钩子的设计

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    pass

class Net(nn.Module):
    pass

def main():
    transform = transforms.ToTensor()
    train_dataset = CustomDataset(transform=transform, ...)
    val_dataset = CustomDataset(transform=transform, ...)
    test_dataset = CustomDataset(transform=transform, ...)
    train_dataloader = DataLoader(train_dataset, ...)
    val_dataloader = DataLoader(val_dataset, ...)
    test_dataloader = DataLoader(test_dataset, ...)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for i in range(max_epochs):
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for inputs, labels in val_dataloader:
                outputs = net(inputs)
                loss = criterion(outputs, labels)

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = net(inputs)
            accuracy = ...

上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 main 函数。为了提高 main 函数的灵活性和拓展性,我们可以在 main 方法中插入位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。

def main():
    ...
    call_hooks('before_run', hooks)  # 任务开始前执行的逻辑
    call_hooks('after_load_checkpoint', hooks)  # 加载权重后执行的逻辑
    call_hooks('before_train', hooks)  # 训练开始前执行的逻辑
    for i in range(max_epochs):
        call_hooks('before_train_epoch', hooks)  # 遍历训练数据集前执行的逻辑
        for inputs, labels in train_dataloader:
            call_hooks('before_train_iter', hooks)  # 模型前向计算前执行的逻辑
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            call_hooks('after_train_iter', hooks)  # 模型前向计算后执行的逻辑
            loss.backward()
            optimizer.step()
        call_hooks('after_train_epoch', hooks)  # 遍历完训练数据集后执行的逻辑

        call_hooks('before_val_epoch', hooks)  # 遍历验证数据集前执行的逻辑
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                call_hooks('before_val_iter', hooks)  # 模型前向计算前执行
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                call_hooks('after_val_iter', hooks)  # 模型前向计算后执行
        call_hooks('after_val_epoch', hooks)  # 遍历完验证数据集前执行

        call_hooks('before_save_checkpoint', hooks)  # 保存权重前执行的逻辑
    call_hooks('after_train', hooks)  # 训练结束后执行的逻辑

    call_hooks('before_test_epoch', hooks)  # 遍历测试数据集前执行的逻辑
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            call_hooks('before_test_iter', hooks)  # 模型前向计算后执行的逻辑
            outputs = net(inputs)
            accuracy = ...
            call_hooks('after_test_iter', hooks)  # 遍历完成测试数据集后执行的逻辑
    call_hooks('after_test_epoch', hooks)  # 遍历完测试数据集后执行

    call_hooks('after_run', hooks)  # 任务结束后执行的逻辑

为了方便管理,MMEngine 将位点定义为方法并集成到钩子基类(Hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。

before_run after_run before_train after_train
before_train_epoch after_train before_train_iter after_train_iter
before_val after_val before_val_epoch after_val_epoch
before_val_iter after_val before_test after_test
before_test_epoch after_test before_test_iter after_test_iter
before_save_checkpoint after_load_checkpoint

你可能感兴趣的:(车道线检测,手写AI,深度学习,人工智能,深度学习)