mmdetection - 训练过程之train_detector

下面是train_detector的主干,我删除了异常判断、版本兼容、分布式训练等内容,下面列出来的是我认为比较重要的部分。

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']
        
#DataLoader,是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,
#一般只要是用PyTorch来训练模型基本都会用到该接口,该接口的目的:将自定义的
#Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type,
            persistent_workers=cfg.data.get('persistent_workers', False))
        for ds in dataset
    ]

#构建优化器,optimizer目的:优化SGD,训练快速收敛并且保证准确率
   optimizer = build_optimizer(model, cfg.optimizer)
# build runner runner(实现在mmcv中)主要是用来管理模型训练时的生命周期,负责 OpenMMLab 中所有框架的训练过程调度,也就是管理何时执行resume、logger、save checkpoint、学习率更新、梯度计算BP等常见操作。
    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # register hooks 注册多个hook,在训练过程中调用,学习率设置、优化器设置、模型保存、日志打印等。
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
#加载模型
    runner.load_checkpoint(cfg.load_from)
 # runner.run-> runner.train-> runner.run_iter->self.model.train_step,进行模型训练
    runner.run(data_loaders, cfg.workflow)

你可能感兴趣的:(mmdtection,pytorch,深度学习,机器学习)