Person Re-identification 02-05: fast-reid (BoT) - PyTorch Coding Conventions (fast-reid as an Example) 2 - DefaultTrainer Explained

The following links collect my personal notes on fast-reid (BoT person re-identification). If you spot any mistakes, please point them out and I will correct them right away. Anyone interested is welcome to add me on WeChat: 17575010159 to discuss. If this helped you, please remember to give it a like, as that is the greatest encouragement for me. At the end of the article you will find my public account, with plenty of resources.

Person Re-identification 02-00: fast-reid (BoT) - Table of Contents - The Latest, Most Thorough Walkthrough

Highly recommended commercial-grade project: this is a behavior-analysis project I have brought to production, consisting of three main modules (1. pedestrian detection, 2. pedestrian tracking, 3. action recognition): Behavior Analysis (Commercial Grade) 00 - Table of Contents - The Latest, Most Thorough Walkthrough

Preface

From the previous post, we already know that every class derived from HookBase provides the following methods:

    def before_train(self):   # called before the first iteration
    def after_train(self):    # called after the last iteration
    def before_step(self):    # called before each iteration
    def after_step(self):     # called after each iteration

We also know roughly when these methods are called and what the overall training flow looks like. What is still unclear is how the hooks are created and which hooks we actually need; that is what this post walks through.
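As a quick illustration, here is a minimal sketch of a custom hook. It assumes HookBase can be imported from fastreid.engine.train_loop (the exact import path may differ between fast-reid versions), and the IterationLogger name and its printing behavior are purely illustrative:

from fastreid.engine.train_loop import HookBase  # import path may vary by fast-reid version


class IterationLogger(HookBase):
    """Illustrative hook: logs progress around the training loop."""

    def before_train(self):
        print("training is about to start")

    def before_step(self):
        # register_hooks() gives each hook a reference to the trainer,
        # so the current iteration is available as self.trainer.iter
        print("about to run iteration", self.trainer.iter)

    def after_step(self):
        pass  # nothing to do after each iteration in this sketch

    def after_train(self):
        print("training finished")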

DefaultTrainer

In the file fastreid\engine\defaults.py we can find the following source code:

class DefaultTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    1. Create model, optimizer, scheduler, dataloader from the given config.

    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.

    3. Register a few common hooks.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.
    The code of this class has been annotated about restrictive assumptions it mades.
    When they do not work for you, you're encouraged to:
    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in fastreid.
    To obtain more stable behavior, write your own training logic with other public APIs.
    Attributes:
        scheduler: the learning-rate scheduler
        checkpointer (DetectionCheckpointer): loads and saves model checkpoints
        cfg (CfgNode): the config node
    Examples:
    .. code-block:: python
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # Create the logger object used for printing log output
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        # Build the training dataset and its data loader
        data_loader = self.build_train_loader(cfg)
        # Automatically derive some config values, e.g. the total number of iterations
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        # Build the model from the config
        model = self.build_model(cfg)
        # Build the optimizer from the config
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )

        super().__init__(model, data_loader, optimizer)

        # Set up the learning-rate decay schedule
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        # Load the specified model weights
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        # Initialize the iteration counters
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        # Build the hooks and register them
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If resume == True, training continues from the previous iteration count; otherwise it starts from iteration 0.
        If `resume==True`, and last checkpoint exists, resume from it.
        Otherwise, load a model specified by the config.
        Args:
            resume (bool): whether to do resume or not
        """
		......
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        You can think of this as putting these classes and functions into a container,
        from which they are fetched and invoked whenever they are needed.
        Returns:
            list[HookBase]:
        """
		......

        return ret

    def build_writers(self):
        """
        Mainly used for writing logs and similar output.
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        # Call the parent class's training loop
        super().train(self.start_iter, self.max_iter)
        # After training finishes, return the results of the last evaluation
        if comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            # verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Build the model from the config cfg.
        Returns:
            torch.nn.Module:
        It now calls :func:`fastreid.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        # logger = logging.getLogger(__name__)
        # logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build the optimizer from the config.
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        Build the learning-rate scheduler from the config.
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Build a training data loader (iterator).
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Build a test data loader (iterator).
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, num_query, output_dir=None):
        """
        Build the evaluator.
        """
        return ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the model.
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.
        Returns:
            dict: a dict of result metrics
        """
        # Logger used for writing log output
        logger = logging.getLogger(__name__)
        # If a single DatasetEvaluator is passed in, wrap it in a list
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]

        # If evaluators is not None, check that its length matches cfg.DATASETS.TEST
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )
        # Create a dict to store the results
        results = OrderedDict()
        
        # Evaluate on each test dataset
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            # Log progress and build the test data loader
            logger.info("Prepare testing set")
            data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
            
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, num_query)
                except NotImplementedError:
                    logger.warn(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            # Run inference on this dataset and collect the results
            results_i = inference_on_dataset(model, data_loader, evaluator)
            # Store the inference results for this dataset
            results[dataset_name] = results_i

        # On the main process, check and print the evaluation results
        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results
            )
            # Print the evaluation results in CSV format
            print_csv_format(results)

        if len(results) == 1: results = list(results.values())[0]

        return results

    @staticmethod
    def auto_scale_hyperparams(cfg, data_loader):
        r"""
        Derive some config values from the given cfg, such as the total number of training iterations.
        This is used for auto-computation actual training iterations,
        because some hyper-param, such as MAX_ITER, means training epochs rather than iters,
        so we need to convert specific hyper-param to training iterations.
        """

        cfg = cfg.clone()
		......
        return cfg

Code walkthrough

From the comments in the code above, we can see the following:

        # Assume these objects must be constructed in this order.
        # Build the training dataset and its data loader
        data_loader = self.build_train_loader(cfg)
        # Automatically derive some config values, e.g. the total number of iterations
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        # Build the model from the config
        model = self.build_model(cfg)
        # Build the optimizer from the config
        optimizer = self.build_optimizer(cfg, model)

This is where the data loader, the optimizer, and the model are created. The def build_hooks(self) method then constructs all of the hooks, such as timing, checkpointing, lr scheduling, precise BN, and event writing. All of these are hooks that inherit from HookBase. The sketch below shows how this class is typically put to use.
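To make the overall flow concrete, here is a small sketch of how DefaultTrainer is typically driven, following the usage shown in its docstring. Overriding build_hooks to append the IterationLogger from the earlier sketch is only one assumed way of adding a custom hook, and the config file path is just an example from the fast-reid repository:

from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer


class MyTrainer(DefaultTrainer):
    def build_hooks(self):
        # Keep the default hooks (timing, checkpointing, lr scheduling, ...)
        ret = super().build_hooks()
        # Append the illustrative custom hook from the sketch above
        ret.append(IterationLogger())
        return ret


cfg = get_cfg()
cfg.merge_from_file("./configs/Market1501/bagtricks_R50.yml")  # example config

trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=True)  # load last checkpoint or cfg.MODEL.WEIGHTS
trainer.train()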

