浅析mmdetection demo

一、import做了啥
每个import都会执行对应模块(py文件)的顶层代码:其中的函数和类在此时被定义(函数体本身并不执行),而装饰器会在被装饰对象定义时立即执行。

#自己编写的用来训练的.py文件
from mmdet.datasets.builder import DATASETS
#mmdet/datasets/builder.py
#这个地方就需要注意了,mmcv.runner是个文件夹,这里被当作包来引用,也就是会执行mmcv/runner/__init__.py
from mmcv.runner import get_dist_info
#mmcv/runner/__init__.py
from .base_runner import BaseRunner
#mmcv/runner/base_runner.py
#这个地方也需要注意了,.hooks是个文件夹,这里被当作包来引用,也就是会执行mmcv/runner/hooks/__init__.py中的代码。
from .hooks import HOOKS, Hook
#mmcv/runner/hooks/__init__.py中的每个import都会执行被导入模块里的注册装饰器,也就是把对应的类名和类添加到HOOKS._module_dict中。

初始化下列全局变量,即实例化N个Registry对象,每个对象内部维护一张全局的key-value映射表(类名 -> 类)。

#mmdet/datasets/builder.py
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
#mmdet/models/builder.py
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
#mmcv/runner/hooks/hook.py
HOOKS = Registry('hook')

二、cfg = Config.fromfile做了啥
把py配置文件解析成Config对象,其中的配置参数以类似dict的形式保存和访问。
三、build_dataset做了啥
实例化数据集类KittiTinyDataset,它的很多属性都继承自父类CustomDataset,包括pipeline。

#根据配置文件创建数据集
datasets = [build_dataset(cfg.data.train)]
#数据类的实例化
class KittiTinyDataset(CustomDataset)
#父类CustomDataset的属性:self.pipeline是Compose类的实例,调用self.pipeline(...)会执行Compose类的__call__函数;self.data_infos保存load_annotations读取到的标注信息。
self.data_infos = self.load_annotations(self.ann_file)
self.pipeline = Compose(pipeline)

四、build_detector做了啥
实例化模型

model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
DETECTORS.build(
        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
#实例化FasterRCNN类,进而实例化模型的整体架构
 class FasterRCNN(TwoStageDetector):
 class TwoStageDetector(BaseDetector):
 class BaseDetector(BaseModule, metaclass=ABCMeta):

五、train_detector做了啥

#将自定义的Dataset根据batch size大小、是否shuffle等封装成DataLoader,迭代时每次产出一个batch的数据,用于后面的训练
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type,
            persistent_workers=cfg.data.get('persistent_workers', False))
        for ds in dataset
    ]
optimizer = build_optimizer(model, cfg.optimizer)

#runner(实现在mmcv中)主要是用来管理模型训练时的生命周期,负责 OpenMMLab 中所有框架的训练过程调度,
#也就是管理何时执行resume、logger、save checkpoint、学习率更新、梯度计算BP等常见操作。
#配置文件中runner的内容示例:runner = dict(type='EpochBasedRunner', max_epochs=12)
#build_runner最后也是调用build_from_cfg,根据配置文件去实例化EpochBasedRunner/IterBasedRunner
runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
# register_training_hooks注册多个hook,在训练过程中调用,学习率设置、优化器设置、模型保存、日志打印等。
#最后也是调用build_from_cfg,根据配置文件去实例化各个Hook类
 runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
 #把实例化的hook插入到self._hooks
# register_training_hooks、register_lr_hook、register_hook等都是BaseRunner类的方法,_hooks是BaseRunner类的属性。
self.register_hook(hook, priority='VERY_HIGH')
self._hooks.insert(i + 1, hook)

#加载模型
    runner.load_checkpoint(cfg.load_from)
 # runner.run-> runner.train-> runner.run_iter->self.model.train_step,进行模型训练
    runner.run(data_loaders, cfg.workflow)
    #根据workflow中mode的取值('train'或'val'),epoch_runner对应self.train或者self.val
 epoch_runner = getattr(self, mode)
 epoch_runner(data_loaders[i], **kwargs)
 #训练一轮数据
def train(self, data_loader, **kwargs):
    """Run one training epoch over ``data_loader``.

    Fires the epoch-level and iteration-level hooks around every batch and
    advances the runner's iteration and epoch counters.

    Args:
        data_loader: Iterable of data batches. It must support ``len()``,
            because the total iteration budget is derived from its length.
        **kwargs: Forwarded unchanged to ``self.run_iter`` for every batch.
    """
    # Calls the model's own train(); presumably switches it to training
    # mode (nn.Module semantics) -- confirm against the model type.
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    # Total iteration budget = configured epochs * batches per epoch.
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        # Index of the current batch within this epoch.
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        # Iteration counter; incremented once per batch and not reset here.
        self._iter += 1

    self.call_hook('after_train_epoch')
    # One full pass over the data loader has completed.
    self._epoch += 1

你可能感兴趣的:(mmdetection,python,深度学习)