近期参与到了手写AI的车道线检测的学习中去,以此系列笔记记录学习与思考的全过程。车道线检测系列会持续更新,力求完整精炼,引人启示。所需前期知识,可以结合手写AI进行系统的学习。
日志系统
钩子编程是一种编程模式,是指在程序的一个或者多个位置设置位点(挂载点),**当程序运行至某个位点时,会自动调用运行时注册到位点的所有方法。**钩子编程可以提高程序的灵活性和拓展性,用户将自定义的方法注册到位点便可被调用而无需修改程序中的代码。
HIGHEST (0)
VERY_HIGH (10)
HIGH (30)
ABOVE_NORMAL (40)
NORMAL (50)
BELOW_NORMAL (60)
LOW (70)
VERY_LOW (90)
LOWEST (100)
名称 | 用途 | 优先级 |
---|---|---|
RuntimeInfoHook | 往 message hub 更新运行时信息 | VERY_HIGH (10) |
IterTimerHook | 统计迭代耗时 | NORMAL (50) |
DistSamplerSeedHook | 确保分布式 Sampler 的 shuffle 生效 | NORMAL (50) |
LoggerHook | 打印日志 | BELOW_NORMAL (60) |
ParamSchedulerHook | 调用 ParamScheduler 的 step 方法 | LOW (70) |
CheckpointHook | 按指定间隔保存权重 | VERY_LOW (90) |
名称 | 用途 | 优先级 |
---|---|---|
EMAHook | 模型参数指数滑动平均 | NORMAL (50) |
EmptyCacheHook | PyTorch CUDA 缓存清理 | NORMAL (50) |
SyncBuffersHook | 同步模型的 buffer | NORMAL (50) |
两种钩子在执行器中的设置不同,默认钩子的配置传给执行器的 default_hooks 参数,自定义钩子的配置传给 custom_hooks 参数,如下所示:
from mmengine.runner import Runner
default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'),
logger=dict(type='LoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
)
custom_hooks = [dict(type='EmptyCacheHook')]
runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()
CheckpointHook 按照给定间隔保存模型的权重,如果是分布式多卡训练,则只有主(master)进程会保存权重。CheckpointHook 的主要功能如下:
按照间隔保存权重,支持按 epoch 数或者 iteration 数保存权重
保存最新的多个权重
保存最优权重
指定保存权重的路径
制作发布用的权重
设置开始保存权重的 epoch 数或者 iteration 数
下面介绍上面提到的 6 个功能。
假设我们一共训练 20 个 epoch 并希望每隔 5 个 epoch 保存一次权重,下面的配置即可帮我们实现该需求。
# by_epoch 的默认值为 True
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))
如果想以迭代次数作为保存间隔,则可以将 by_epoch 设为 False,interval=5 则表示每迭代 5 次保存一次权重。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
如果只想保存一定数量的权重,可以通过设置 max_keep_ckpts 参数实现最多保存 max_keep_ckpts 个权重,当保存的权重数超过 max_keep_ckpts 时,前面的权重会被删除。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))
上述例子表示,假如一共训练 20 个 epoch,那么会在第 5, 10, 15, 20 个 epoch 保存模型,但是在第 15 个 epoch 的时候会删除第 5 个 epoch 保存的权重,在第 20 个 epoch 的时候会删除第 10 个 epoch 的权重,最终只有第 15 和第 20 个 epoch 的权重才会被保存。
如果想要保存训练过程验证集的最优权重,可以设置 save_best 参数,如果设置为 ‘auto’,则会根据验证集的第一个评价指标(验证集返回的评价指标是一个有序字典)判断当前权重是否最优。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))
也可以直接指定 save_best 的值为评价指标,例如在分类任务中,可以指定为 save_best=‘top-1’,则会根据 ‘top-1’ 的值判断当前权重是否最优。
除了 save_best 参数,和保存最优权重相关的参数还有 rule,greater_keys 和 less_keys,这三者用来判断 save_best 的值是越大越好还是越小越好。例如指定了 save_best=‘top-1’,可以指定 rule=‘greater’,则表示该值越大表示权重越好。
权重默认保存在工作目录(work_dir),但可以通过设置 out_dir 改变保存路径。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
如果你想在训练结束后自动生成可发布的权重(删除不需要的权重,例如优化器状态),你可以设置 published_keys 参数,选择需要保留的信息。注意:需要相应设置 save_best 或者 save_last 参数,这样才会生成可发布的权重,其中设置 save_best 会生成最优权重的可发布权重,设置 save_last 会生成最后一个权重的可发布权重,这两个参数也可同时设置。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
如果想要设置控制开始保存权重的 epoch 数或者 iteration 数,可以设置 save_begin 参数,默认为 0,表示从训练开始就保存权重。例如,如果总共训练 10 个 epoch,并且 save_begin 设置为 5,则将保存第 5、6、7、8、9 和 10 个 epoch 的权重。如果 interval=2,则仅保存第 5、7 和 9 个 epoch 的权重。
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=2, save_begin=5))
LoggerHook负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。
如果我们希望每迭代 20 次就输出(或保存)一次日志,我们可以设置 interval 参数,配置如下:
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
log_processor = dict(window_size=10, by_epoch=True, custom_cfg=None, num_digits=4)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel
train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)
class ToyModel(BaseModel):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(1, 1)
def forward(self, img, label, mode):
feat = self.linear(img)
loss1 = (feat - label).pow(2)
loss2 = (feat - label).abs()
return dict(loss1=loss1, loss2=loss2)
runner = Runner(
model=ToyModel(),
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=1),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01))
)
runner.train()
终端显示内容
08/21 02:58:41 - mmengine - INFO - Epoch(train) [1][10/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0019 data_time: 0.0004 loss1: 0.8381 loss2: 0.9007 loss: 1.7388
08/21 02:58:41 - mmengine - INFO - Epoch(train) [1][20/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0029 data_time: 0.0010 loss1: 0.1978 loss2: 0.4312 loss: 0.6290
日志前缀:
Epoch 模式(by_epoch=True):Epoch(train) [{当前epoch次数}][{当前迭代次数}/{Dataloader 总长度}]
Iter 模式(by_epoch=False): Iter(train) [{当前迭代次数}/{总迭代次数}]
学习率(lr):统计最近一次迭代,参数更新的学习率
时间
迭代时间(time):最近 window_size(日志处理器参数) 次迭代,处理一个 batch 数据(包括数据加载和模型前向推理)的平均时间
数据时间(data_time):最近 window_size 次迭代,加载一个 batch 数据的平均时间
剩余时间(eta):根据总迭代次数和历次迭代时间计算出来的总剩余时间,剩余时间随着迭代次数增加逐渐趋于稳定
损失:模型前向推理得到的各种字段的损失,默认统计最近 window_size 次迭代的平均损失。
默认情况下,window_size=10,日志处理器会统计最近 10 次迭代,损失、迭代时间、数据时间的均值。
默认情况下,所有日志的有效位数(num_digits 参数)为 4。
默认情况下,输出所有自定义日志最近一次迭代的值。
基于上述规则,代码示例中的日志处理器会输出 loss1 和 loss2 每 10 次迭代的均值。如果我们想统计 loss1 从第一次迭代开始至今的全局均值,可以这样配置:
runner = Runner(
model=ToyModel(),
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=1),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
log_processor=dict( # 配置日志处理器
custom_cfg=[
dict(data_src='loss1', # 原日志名:loss1
method_name='mean', # 统计方法:均值统计
window_size='global')]) # 统计窗口:全局
)
runner.train()
08/21 02:58:49 - mmengine - INFO - Epoch(train) [1][10/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0026 data_time: 0.0007 loss1: 0.7381 loss2: 0.8446 loss: 1.5827
08/21 02:58:49 - mmengine - INFO - Epoch(train) [1][20/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0030 data_time: 0.0012 loss1: 0.4521 loss2: 0.3939 loss: 0.5600
log_processor 默认输出 by_epoch=True 格式的日志。日志格式需要和 train_cfg 中的 by_epoch 参数保持一致,例如我们想按迭代次数输出日志,就需要另 log_processor 和 train_cfg 的 by_epoch=False。
其中 data_src 为原日志名,mean 为统计方法,global 为统计方法的参数。这样的话,日志中统计的 loss1 就是全局均值。我们可以在日志处理器中配置以下统计方法:
统计方法 | 参数 | 功能 |
---|---|---|
mean | window_size | 统计窗口内日志的均值 |
min | window_size | 统计窗口内日志的最小值 |
max | window_size | 统计窗口内日志的最大值 |
current | / | 返回最近一次更新的日志 |
其中window_size的值可以是:
数字:表示统计窗口的大小
global:统计全局的最大、最小和均值
epoch:统计一个 epoch 内的最大、最小和均值
如果我们既想统计窗口为 10 的 loss1 的局部均值,又想统计 loss1 的全局均值,则需要额外指定 log_name:
runner = Runner(
model=ToyModel(),
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=1),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
log_processor=dict(
custom_cfg=[
# log_name 表示 loss1 重新统计后的日志名
dict(data_src='loss1', log_name='loss1_global', method_name='mean', window_size='global')])
)
runner.train()
08/21 18:39:32 - mmengine - INFO - Epoch(train) [1][10/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0016 data_time: 0.0004 loss1: 0.1512 loss2: 0.3751 loss: 0.5264 loss1_global: 0.1512
08/21 18:39:32 - mmengine - INFO - Epoch(train) [1][20/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0051 data_time: 0.0036 loss1: 0.0113 loss2: 0.0856 loss: 0.0970 loss1_global: 0.0813
类似地,我们也可以统计 loss1 的局部最大值和全局最大值:
runner = Runner(
model=ToyModel(),
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=1),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
log_processor=dict(custom_cfg=[
# 统计 loss1 的局部最大值,统计窗口为 10,并在日志中重命名为 loss1_local_max
dict(data_src='loss1',
log_name='loss1_local_max',
window_size=10,
method_name='max'),
# 统计 loss1 的全局最大值,并在日志中重命名为 loss1_local_max
dict(
data_src='loss1',
log_name='loss1_global_max',
method_name='max',
window_size='global')
]))
runner.train()
08/21 03:17:26 - mmengine - INFO - Epoch(train) [1][10/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0021 data_time: 0.0006 loss1: 1.8495 loss2: 1.3427 loss: 3.1922 loss1_local_max: 2.8872 loss1_global_max: 2.8872
08/21 03:17:26 - mmengine - INFO - Epoch(train) [1][20/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0024 data_time: 0.0010 loss1: 0.5464 loss2: 0.7251 loss: 1.2715 loss1_local_max: 2.8872 loss1_global_max: 2.8872
from mmengine.logging import MessageHub
class ToyModel(BaseModel):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(1, 1)
def forward(self, img, label, mode):
feat = self.linear(img)
loss_tmp = (feat - label).abs()
loss = loss_tmp.pow(2)
message_hub = MessageHub.get_current_instance()
# 在日志中额外统计 `loss_tmp`
message_hub.update_scalar('train/loss_tmp', loss_tmp.sum())
return dict(loss=loss)
runner = Runner(
model=ToyModel(),
work_dir='tmp_dir',
train_dataloader=train_dataloader,
train_cfg=dict(by_epoch=True, max_epochs=1),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
log_processor=dict(
custom_cfg=[
# 统计 loss_tmp 的局部均值
dict(
data_src='loss_tmp',
window_size=10,
method_name='mean')
]
)
)
runner.train()
08/21 03:40:31 - mmengine - INFO - Epoch(train) [1][10/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0026 data_time: 0.0008 loss_tmp: 0.0097 loss: 0.0000
08/21 03:40:31 - mmengine - INFO - Epoch(train) [1][20/25] lr: 1.0000e-02 eta: 0:00:00 time: 0.0028 data_time: 0.0013 loss_tmp: 0.0065 loss: 0.0000
通过调用消息枢纽的接口实现自定义日志的统计,具体步骤如下:
调用 get_current_instance 接口获取执行器的消息枢纽。
调用 update_scalar 接口更新日志内容,其中第一个参数为日志的名称,日志名称以 train/,val/,test/ 前缀打头,用于区分训练状态,然后才是实际的日志名,如上例中的 train/loss_tmp,这样统计的日志中就会出现 loss_tmp。
配置日志处理器,以均值的方式统计 loss_tmp。如果不配置,日志里显示 loss_tmp 最近一次更新的值。
runner = Runner(
model=ToyModel(),
work_dir='tmp_dir',
train_dataloader=train_dataloader,
log_level='DEBUG',
train_cfg=dict(by_epoch=True, max_epochs=1),
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)))
runner.train()
08/21 18:16:22 - mmengine - DEBUG - Get class `LocalVisBackend` from "vis_backend" registry in "mmengine"
08/21 18:16:22 - mmengine - DEBUG - An `LocalVisBackend` instance is built from registry, its implementation can be found in mmengine.visualization.vis_backend
08/21 18:16:22 - mmengine - DEBUG - Get class `RuntimeInfoHook` from "hook" registry in "mmengine"
08/21 18:16:22 - mmengine - DEBUG - An `RuntimeInfoHook` instance is built from registry, its implementation can be found in mmengine.hooks.runtime_info_hook
08/21 18:16:22 - mmengine - DEBUG - Get class `IterTimerHook` from "hook" registry in "mmengine"
...
ParamSchedulerHook 遍历执行器的所有优化器参数调整策略(Parameter Scheduler)并逐个调用 step 方法更新优化器的参数。ParamSchedulerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
IterTimerHook 用于记录加载数据的时间以及迭代一次耗费的时间。IterTimerHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
DistSamplerSeedHook 在分布式训练时调用 Sampler 的 step 方法以确保 shuffle 参数生效。DistSamplerSeedHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
RuntimeInfoHook 会在执行器的不同钩子位点将当前的运行时信息(如 epoch、iter、max_epochs、max_iters、lr、metrics等)更新至 message hub 中,以便其他无法访问执行器的模块能够获取到这些信息。RuntimeInfoHook 默认注册到执行器并且没有可配置的参数,所以无需对其做任何配置。
EMAHook 在训练过程中对模型执行指数滑动平均操作,目的是提高模型的鲁棒性。注意:指数滑动平均生成的模型只用于验证和测试,不影响训练。
custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
EMAHook 默认使用 ExponentialMovingAverage,可选值还有 StochasticWeightAverage 和 MomentumAnnealingEMA。可以通过设置 ema_type 使用其他的平均策略。
custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]
EmptyCacheHook 调用 torch.cuda.empty_cache() 释放未被使用的显存。可以通过设置 before_epoch, after_iter 以及 after_epoch 参数控制释显存的时机,第一个参数表示在每个 epoch 开始之前,第二参数表示在每次迭代之后,第三个参数表示在每个 epoch 之后。
# 每一个 epoch 结束都会执行释放操作
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
SyncBuffersHook 在分布式训练每一轮(epoch)结束时同步模型的 buffer,例如 BN 层的 running_mean 以及 running_var。
custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 after_train_iter 位点。
import torch
from mmengine.registry import HOOKS
from mmengine.hooks import Hook
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
"""Check invalid loss hook.
This hook will regularly check whether the loss is valid
during training.
Args:
interval (int): Checking interval (every k iterations).
Defaults to 50.
"""
def __init__(self, interval=50):
self.interval = interval
def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
"""All subclasses should override this method, if they need any
operations after each training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
"""
if self.every_n_train_iters(runner, self.interval):
assert torch.isfinite(outputs['loss']),\
runner.logger.info('loss become infinite or NaN!')
我们只需将钩子的配置传给执行器的 custom_hooks 的参数,执行器初始化的时候会注册钩子
from mmengine.runner import Runner
custom_hooks = [
dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...) # 实例化执行器,主要完成环境的初始化以及各种模块的构建
runner.train() # 执行器开始训练
便会在每次模型前向计算后检查损失值。
注意,自定义钩子的优先级默认为 NORMAL (50),如果想改变钩子的优先级,则可以在配置中设置 priority 字段。
custom_hooks = [
dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]
也可以在定义类时给定优先级
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
priority = 'ABOVE_NORMAL'
pre_hooks = [(print, 'hello')]
post_hooks = [(print, 'goodbye')]
def main():
for func, arg in pre_hooks:
func(arg)
print('do something here')
for func, arg in post_hooks:
func(arg)
main()
hello
do something here
goodbye
在 PyTorch 中,钩子的应用也随处可见,例如神经网络模块(nn.Module)中的钩子可以获得模块的前向输入输出以及反向的输入输出。以 register_forward_hook 方法为例,该方法往模块注册一个前向钩子,钩子可以获得模块的前向输入和输出。
下面是 register_forward_hook 用法的简单示例:
import torch
import torch.nn as nn
def forward_hook_fn(
module, # 被注册钩子的对象
input, # module 前向计算的输入
output, # module 前向计算的输出
):
print(f'"forward_hook_fn" is invoked by {module.name}')
print('weight:', module.weight.data)
print('bias:', module.bias.data)
print('input:', input)
print('output:', output)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 1)
def forward(self, x):
y = self.fc(x)
return y
model = Model()
# 将 forward_hook_fn 注册到 model 每个子模块
for module in model.children():
module.register_forward_hook(forward_hook_fn)
x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)
"forward_hook_fn" is invoked by Linear(in_features=3, out_features=1, bias=True)
weight: tensor([[-0.4077, 0.0119, -0.3606]])
bias: tensor([-0.2943])
input: (tensor([[0., 1., 2.]]),)
output: tensor([[-1.0036]], grad_fn=<AddmmBackward>)
可以看到注册到 Linear 模块的 forward_hook_fn 钩子被调用,在该钩子中打印了 Linear 模块的权重、偏置、模块的输入以及输出。
在介绍 MMEngine 中钩子的设计之前,先简单介绍使用 PyTorch 实现模型训练的基本步骤
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
pass
class Net(nn.Module):
pass
def main():
transform = transforms.ToTensor()
train_dataset = CustomDataset(transform=transform, ...)
val_dataset = CustomDataset(transform=transform, ...)
test_dataset = CustomDataset(transform=transform, ...)
train_dataloader = DataLoader(train_dataset, ...)
val_dataloader = DataLoader(val_dataset, ...)
test_dataloader = DataLoader(test_dataset, ...)
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for i in range(max_epochs):
for inputs, labels in train_dataloader:
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
with torch.no_grad():
for inputs, labels in val_dataloader:
outputs = net(inputs)
loss = criterion(outputs, labels)
with torch.no_grad():
for inputs, labels in test_dataloader:
outputs = net(inputs)
accuracy = ...
上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 main 函数。为了提高 main 函数的灵活性和拓展性,我们可以在 main 方法中插入位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。
def main():
...
call_hooks('before_run', hooks) # 任务开始前执行的逻辑
call_hooks('after_load_checkpoint', hooks) # 加载权重后执行的逻辑
call_hooks('before_train', hooks) # 训练开始前执行的逻辑
for i in range(max_epochs):
call_hooks('before_train_epoch', hooks) # 遍历训练数据集前执行的逻辑
for inputs, labels in train_dataloader:
call_hooks('before_train_iter', hooks) # 模型前向计算前执行的逻辑
outputs = net(inputs)
loss = criterion(outputs, labels)
call_hooks('after_train_iter', hooks) # 模型前向计算后执行的逻辑
loss.backward()
optimizer.step()
call_hooks('after_train_epoch', hooks) # 遍历完训练数据集后执行的逻辑
call_hooks('before_val_epoch', hooks) # 遍历验证数据集前执行的逻辑
with torch.no_grad():
for inputs, labels in val_dataloader:
call_hooks('before_val_iter', hooks) # 模型前向计算前执行
outputs = net(inputs)
loss = criterion(outputs, labels)
call_hooks('after_val_iter', hooks) # 模型前向计算后执行
call_hooks('after_val_epoch', hooks) # 遍历完验证数据集前执行
call_hooks('before_save_checkpoint', hooks) # 保存权重前执行的逻辑
call_hooks('after_train', hooks) # 训练结束后执行的逻辑
call_hooks('before_test_epoch', hooks) # 遍历测试数据集前执行的逻辑
with torch.no_grad():
for inputs, labels in test_dataloader:
call_hooks('before_test_iter', hooks) # 模型前向计算后执行的逻辑
outputs = net(inputs)
accuracy = ...
call_hooks('after_test_iter', hooks) # 遍历完成测试数据集后执行的逻辑
call_hooks('after_test_epoch', hooks) # 遍历完测试数据集后执行
call_hooks('after_run', hooks) # 任务结束后执行的逻辑
为了方便管理,MMEngine 将位点定义为方法并集成到钩子基类(Hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。
before_run
after_run
before_train
after_train
before_train_epoch
after_train_epoch
before_train_iter
after_train_iter
before_val
after_val
before_val_epoch
after_val_epoch
before_val_iter
after_val_iter
before_test
after_test
before_test_epoch
after_test_epoch
before_test_iter
after_test_iter
before_save_checkpoint
after_load_checkpoint
pre_hooks = [(print, 'hello')]
post_hooks = [(print, 'goodbye')]
def main():
for func, arg in pre_hooks:
func(arg)
print('do something here')
for func, arg in post_hooks:
func(arg)
main()
hello
do something here
goodbye
在 PyTorch 中,钩子的应用也随处可见,例如神经网络模块(nn.Module)中的钩子可以获得模块的前向输入输出以及反向的输入输出。以 register_forward_hook 方法为例,该方法往模块注册一个前向钩子,钩子可以获得模块的前向输入和输出。
import torch
import torch.nn as nn
def forward_hook_fn(
module, # 被注册钩子的对象
input, # module 前向计算的输入
output, # module 前向计算的输出
):
print(f'"forward_hook_fn" is invoked by {module.name}')
print('weight:', module.weight.data)
print('bias:', module.bias.data)
print('input:', input)
print('output:', output)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 1)
def forward(self, x):
y = self.fc(x)
return y
model = Model()
# 将 forward_hook_fn 注册到 model 每个子模块
for module in model.children():
module.register_forward_hook(forward_hook_fn)
x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)
"forward_hook_fn" is invoked by Linear(in_features=3, out_features=1, bias=True)
weight: tensor([[-0.4077, 0.0119, -0.3606]])
bias: tensor([-0.2943])
input: (tensor([[0., 1., 2.]]),)
output: tensor([[-1.0036]], grad_fn=<AddmmBackward>)
可以看到注册到 Linear 模块的 forward_hook_fn 钩子被调用,在该钩子中打印了 Linear 模块的权重、偏置、模块的输入以及输出。更多关于 PyTorch 钩子的用法可以阅读 nn.Module。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
pass
class Net(nn.Module):
pass
def main():
transform = transforms.ToTensor()
train_dataset = CustomDataset(transform=transform, ...)
val_dataset = CustomDataset(transform=transform, ...)
test_dataset = CustomDataset(transform=transform, ...)
train_dataloader = DataLoader(train_dataset, ...)
val_dataloader = DataLoader(val_dataset, ...)
test_dataloader = DataLoader(test_dataset, ...)
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for i in range(max_epochs):
for inputs, labels in train_dataloader:
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
with torch.no_grad():
for inputs, labels in val_dataloader:
outputs = net(inputs)
loss = criterion(outputs, labels)
with torch.no_grad():
for inputs, labels in test_dataloader:
outputs = net(inputs)
accuracy = ...
上面的伪代码是训练模型的基本步骤。如果要在上面的代码中加入定制化的逻辑,我们需要不断修改和拓展 main 函数。为了提高 main 函数的灵活性和拓展性,我们可以在 main 方法中插入位点,并在对应位点实现调用 hook 的抽象逻辑。此时只需在这些位点插入 hook 来实现定制化逻辑,即可添加定制化功能,例如加载模型权重、更新模型参数等。
def main():
...
call_hooks('before_run', hooks) # 任务开始前执行的逻辑
call_hooks('after_load_checkpoint', hooks) # 加载权重后执行的逻辑
call_hooks('before_train', hooks) # 训练开始前执行的逻辑
for i in range(max_epochs):
call_hooks('before_train_epoch', hooks) # 遍历训练数据集前执行的逻辑
for inputs, labels in train_dataloader:
call_hooks('before_train_iter', hooks) # 模型前向计算前执行的逻辑
outputs = net(inputs)
loss = criterion(outputs, labels)
call_hooks('after_train_iter', hooks) # 模型前向计算后执行的逻辑
loss.backward()
optimizer.step()
call_hooks('after_train_epoch', hooks) # 遍历完训练数据集后执行的逻辑
call_hooks('before_val_epoch', hooks) # 遍历验证数据集前执行的逻辑
with torch.no_grad():
for inputs, labels in val_dataloader:
call_hooks('before_val_iter', hooks) # 模型前向计算前执行
outputs = net(inputs)
loss = criterion(outputs, labels)
call_hooks('after_val_iter', hooks) # 模型前向计算后执行
call_hooks('after_val_epoch', hooks) # 遍历完验证数据集前执行
call_hooks('before_save_checkpoint', hooks) # 保存权重前执行的逻辑
call_hooks('after_train', hooks) # 训练结束后执行的逻辑
call_hooks('before_test_epoch', hooks) # 遍历测试数据集前执行的逻辑
with torch.no_grad():
for inputs, labels in test_dataloader:
call_hooks('before_test_iter', hooks) # 模型前向计算后执行的逻辑
outputs = net(inputs)
accuracy = ...
call_hooks('after_test_iter', hooks) # 遍历完成测试数据集后执行的逻辑
call_hooks('after_test_epoch', hooks) # 遍历完测试数据集后执行
call_hooks('after_run', hooks) # 任务结束后执行的逻辑
为了方便管理,MMEngine 将位点定义为方法并集成到钩子基类(Hook)中,我们只需继承钩子基类并根据需求在特定位点实现定制化逻辑,再将钩子注册到执行器中,便可自动调用钩子中相应位点的方法。
before_run | after_run | before_train | after_train |
before_train_epoch | after_train | before_train_iter | after_train_iter |
before_val | after_val | before_val_epoch | after_val_epoch |
before_val_iter | after_val | before_test | after_test |
before_test_epoch | after_test | before_test_iter | after_test_iter |
before_save_checkpoint | after_load_checkpoint |