detectron2 (object detection framework) explored from every angle - 00: Table of Contents
From the previous posts we already know detectron2's overall architecture. Going back to the source in detectron2/engine/train_loop.py, we find TrainerBase. Let's see what descendants it has:
# The ancestor: detectron2/engine/train_loop.py
class TrainerBase:
# First-generation descendant: detectron2/engine/train_loop.py
class SimpleTrainer(TrainerBase):
# Second-generation descendant: detectron2/engine/defaults.py
class DefaultTrainer(SimpleTrainer):
# Third-generation descendant: tools/train_my.py - my own implementation, based on the source
class Trainer(DefaultTrainer):
Let's start with the ancestor:
class TrainerBase:
    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter
        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()

    def run_step(self):
        raise NotImplementedError
I have trimmed most of the comments here; the English comments in the source are worth reading. Overall it is quite simple: the class already implements the hook-dispatch methods, and leaves one method for subclasses:
# Already implemented
def before_train(self):  def after_train(self):  def before_step(self):  def after_step(self):
# Declared, to be implemented by a subclass
def run_step(self):
    raise NotImplementedError
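To make the dispatch order concrete, here is a minimal runnable sketch (assuming detectron2 is installed; ToyTrainer and PrintHook are made-up names, not part of the library):

from detectron2.engine.train_loop import HookBase, TrainerBase

class PrintHook(HookBase):
    def before_train(self):
        print("before_train")
    def after_train(self):
        print("after_train")
    def before_step(self):
        print("before_step", self.trainer.iter)
    def after_step(self):
        print("after_step", self.trainer.iter)

class ToyTrainer(TrainerBase):
    def run_step(self):
        # a real trainer would run forward/backward here
        print("run_step", self.iter)

trainer = ToyTrainer()
trainer.register_hooks([PrintHook()])
trainer.train(start_iter=0, max_iter=2)
# prints before_train, then (before_step, run_step, after_step) for
# iterations 0 and 1, and finally after_train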
From the source we can see that before_train, after_train, before_step, and after_step are all implemented the same trivial way: each loops over self._hooks and calls the corresponding function on every hook. So what exactly is self._hooks? As the name says, a list of hooks! Let's set that aside for now; the same goes for
def register_hooks(self, hooks)
which we will come back to later. First, let's look at the first-generation descendant, class SimpleTrainer(TrainerBase). It overrides def run_step(self), and additionally implements:
# Detect anomalies (NaN/Inf) in the loss
def _detect_anomaly(self, losses, loss_dict):
# Can be thought of simply as logging the metrics
def _write_metrics(self, metrics_dict: dict):
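For reference, _detect_anomaly is essentially just a finiteness check on the summed loss; a close paraphrase of the source (slightly condensed):

import torch

def _detect_anomaly(self, losses, loss_dict):
    # Abort training as soon as the total loss becomes NaN or Inf
    if not torch.isfinite(losses).all():
        raise FloatingPointError(
            "Loss became infinite or NaN at iteration={}!\n"
            "loss_dict = {}".format(self.iter, loss_dict)
        )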
The core, though, is clearly def run_step(self), overridden as follows:
def run_step(self):
    """
    Implement the standard training logic described above.
    """
    # Make sure the model is still in training mode
    assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
    start = time.perf_counter()
    """
    # Fetch one batch of data; the dataloader can be wrapped if necessary
    If you want to do something with the data, you can wrap the dataloader.
    """
    data = next(self._data_loader_iter)
    data_time = time.perf_counter() - start
    """
    # The loss computation can be customized if necessary
    If you want to do something with the losses, you can wrap the model.
    """
    loss_dict = self.model(data)
    losses = sum(loss for loss in loss_dict.values())
    # Check whether the loss computation produced anomalies
    self._detect_anomaly(losses, loss_dict)
    # Write to the log
    metrics_dict = loss_dict
    metrics_dict["data_time"] = data_time
    self._write_metrics(metrics_dict)
    """
    # Back-propagate
    If you need to accumulate gradients or do something similar, you can
    wrap the optimizer with your custom `zero_grad()` method.
    """
    self.optimizer.zero_grad()
    losses.backward()
    """
    # One iteration is complete
    If you need gradient clipping/scaling or other processing, you can
    wrap the optimizer with your custom `step()` method.
    """
    self.optimizer.step()
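The docstrings above suggest wrapping the optimizer for gradient accumulation or clipping. As a quick illustration of that idea (ClippingOptimizer and max_norm are made-up names, not part of detectron2):

import torch

class ClippingOptimizer:
    """Wraps an optimizer so that step() clips gradients first."""
    def __init__(self, optimizer, parameters, max_norm=10.0):
        self._opt = optimizer
        self._params = list(parameters)
        self._max_norm = max_norm

    def zero_grad(self):
        self._opt.zero_grad()

    def step(self):
        # Clip the global gradient norm before the actual parameter update
        torch.nn.utils.clip_grad_norm_(self._params, self._max_norm)
        self._opt.step()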
The run_step logic itself is easy to follow, and by this point one full iteration, back-propagation included, is complete. Let's move on to the second-generation descendant, class DefaultTrainer(SimpleTrainer), located in detectron2/engine/defaults.py. It is somewhat more complex, but that's fine; we will work through it step by step. In its initialization function we again see:
model = self.build_model(cfg)                 # build the model
optimizer = self.build_optimizer(cfg, model)  # build the optimizer
data_loader = self.build_train_loader(cfg)    # build the training data loader
These look familiar. DefaultTrainer mainly implements the following functions (a usage sketch follows the list):
# Resume training, or load a pretrained model
def resume_or_load(self, resume=True):
# Build the hooks related to training
def build_hooks(self):
# Mainly calls the parent class's train
def train(self):
# Build the network model from cfg
def build_model(cls, cfg):
# Build the SGD optimizer
def build_optimizer(cls, cfg, model):
# Define the learning-rate decay schedule
def build_lr_scheduler(cls, cfg, optimizer):
# Build the training data loader
def build_train_loader(cls, cfg):
# Build the test data loader
def build_test_loader(cls, cfg, dataset_name):
# Used for evaluation during training; note that it is empty here, not implemented
def build_evaluator(cls, cfg, dataset_name):
# Run evaluation on the data
def test(cls, cfg, model, evaluators=None):
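Putting these together, typical usage of DefaultTrainer looks roughly like this (a sketch: the config file is a standard model-zoo one, and my_dataset_train is assumed to be a dataset you have registered beforehand):

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
cfg.DATASETS.TRAIN = ("my_dataset_train",)  # assumed to be registered already
cfg.SOLVER.MAX_ITER = 1000

trainer = DefaultTrainer(cfg)          # builds model, optimizer, data loader
trainer.resume_or_load(resume=False)   # load initial weights instead of resuming
trainer.train()                        # runs the parent train() loop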
As you can see, this second-generation descendant is already fairly complete; essentially only
def build_evaluator(cls, cfg, dataset_name):
still needs to be overridden by a subclass. Beyond that, there is one more key piece, which of course is:
# Build the hooks related to training
def build_hooks(self):
Let's set that aside for now and look at the third-generation descendant, class Trainer(DefaultTrainer), in my own tools/train_my.py. It implements (a sketch of a typical override follows the list):
# Build the evaluator according to the cfg configuration
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
# Copied straight from the reference code; judging by the name, it evaluates the model with test-time augmentation (TTA)
def test_with_TTA(cls, cfg, model):
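For instance, a build_evaluator override for a COCO-format dataset typically looks like the sketch below (COCOEvaluator is detectron2's COCO-style evaluator; the exact body in train_my.py may differ):

import os
from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator

class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        # distributed=True lets it gather results across GPUs
        return COCOEvaluator(dataset_name, cfg, True, output_folder)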
At this point we have briefly walked through everything from the ancestor down to the third generation. One key topic remains: the Hook.
We first encountered hooks in the ancestor TrainerBase's initialization function:
class TrainerBase:
    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()
From this it is clear that the self._hooks list can hold any number of hooks: whenever before_train, after_train, before_step, or after_step is called, the trainer loops over self._hooks and calls the corresponding function on every hook, and register_hooks(self, hooks) is simply how hooks get registered into that list. Let's first look at:
class HookBase:
    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass
There doesn't seem to be much to see: four methods are declared, but none of them actually does anything. Searching the source for where HookBase is used shows that it has quite a few subclasses, all implemented in detectron2/engine/hooks.py:
# Lets you register custom callback functions
class CallbackHook(HookBase):
# Records and tracks timing over the course of training
class IterationTimer(HookBase):
# Periodically writes events/logs during the iterations
class PeriodicWriter(HookBase):
# Periodically saves model checkpoints
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
# Learning-rate schedule: after every iteration, checks whether the learning rate should change
class LRScheduler(HookBase):
# Runs evaluation once the iteration count reaches a specified period
class EvalHook(HookBase):
# Can be loosely understood as an upgraded version of BN (it recomputes more precise BN statistics)
class PreciseBN(HookBase):
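For orientation, DefaultTrainer.build_hooks wires several of these together by default; a condensed paraphrase from memory (omitting the PreciseBN and multi-GPU main-process details, so not the verbatim source):

def build_hooks(self):
    cfg = self.cfg.clone()
    ret = [
        hooks.IterationTimer(),
        hooks.LRScheduler(self.optimizer, self.scheduler),
        # hooks.PreciseBN(...) is also added here when enabled in cfg
    ]
    # checkpointing, periodic evaluation, and log writers
    ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, lambda: self.test(self.cfg, self.model)))
    ret.append(hooks.PeriodicWriter(self.build_writers()))
    return ret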
This is only a brief introduction; if we use any of these hooks later, I will cover them in detail. Hooks are a genuinely useful idea, and you may well reach for them when running ablation experiments.
In short, we can create all kinds of hooks: as long as a hook inherits from HookBase, it can be registered through TrainerBase.register_hooks. Each hook may implement any of the following functions (a small custom-hook sketch follows the list):
def before_train(self):
def after_train(self):
def before_step(self):
def after_step(self):
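As a concrete illustration of writing your own hook, here is a minimal sketch (LogLRHook is a made-up name) that logs the learning rate every 100 iterations, using the trainer attributes we have already seen:

from detectron2.engine.train_loop import HookBase

class LogLRHook(HookBase):
    def after_step(self):
        # self.trainer is set by register_hooks; optimizer and storage come from the trainer
        if self.trainer.iter % 100 == 0:
            lr = self.trainer.optimizer.param_groups[0]["lr"]
            self.trainer.storage.put_scalar("custom_lr", lr)

# registered like any other hook:
# trainer.register_hooks([LogLRHook()])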
As an example of a built-in hook, take class EvalHook(HookBase). It exists to evaluate the model during training. Generally, evaluation only happens after a certain number of iterations, so the class overrides:
def after_step(self):
Once the iteration count hits the specified period, evaluation runs. Its initialization function is:
def __init__(self, eval_period, eval_function):
    self._period = eval_period
    self._func = eval_function
It takes two arguments: the evaluation period and the evaluation (test) function.
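So registering one is as simple as the sketch below (run_eval is a made-up stand-in for a real evaluation function):

from detectron2.engine.hooks import EvalHook

def run_eval():
    # would run inference on the validation set and return a dict of metrics
    return {"bbox/AP": 0.0}

trainer.register_hooks([EvalHook(5000, run_eval)])  # evaluate every 5000 iterations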
With that, our grasp of the overall picture has advanced another step. In the next section we will look at data preprocessing, i.e., the training data loader.