MMSegmentation 训练测试全流程

MMSegmentation 训练测试全流程

  • 1.按照执行顺序的流程梳理
      • Level 0: 运行 Shell 命令:
      • Level 1: 在 tools/train.py 内:
      • Level 2: 转进到 mmseg.apis 模块的 train_segmentor 函数内:
      • Level 3: 转进到 mmcv/runner/iter_based_runner.py 内的 IterBasedRunner 类的 run 函数内部:
      • Level 4: 转进到 IterBasedRunner 类的 train 函数内部
      • Level 5: 转进到 EvalHook 类实例的 after_train_iter 函数内部:
  • 4.函数说明:
  • 5.疑问解答
  • 参考链接:

括号的部分可以不看!是debug经过的内容,有些事调用了mmcv库的函数,只想看看流程不需要细看!

1.按照执行顺序的流程梳理

Level 0: 运行 Shell 命令:

  • python tools/train.py ${CONFIG_FILE [optional arguments]

Level 1: 在 tools/train.py 内:

  • 读取各种 config: cfg = Config.fromfile(args.config)
  • 创建 model: model = build_segmentor(cfg.model, train_cfg, test_cfg)
  • 创建 training dataset: datasets = [build_dataset(cfg.data.train)]()
    • 通过Config类的__getattr__函数:value = super(ConfigDict, self).__getattr_获取数据和数据增强信息并返回value
    • 转到mmseg/datasets/builder.py内的build_dataset函数,获取dataset:dataset = build_from_cfg(cfg, DATASETS, default_args)
    • 转到/usr/local/lib/python3.8/dist-packages/mmcv/utils/registry.py内的build_from_cfg函数:
      • args = cfg.copy()
      • 获取数据格式类型:obj_type = args.pop('type'),比如obj_type:ADE20KDatase
      • 通过数据格式obj_type获得类obj_cls = registry.get(obj_type),比如
      • 获取return obj_cls(**args)
    • (转到/usr/lib/python3.8/typing.pyGeneric类的__new__函数:obj = super().__new__(cls))
    • 转到mmseg/datasets/ade.py中的ADE20KDataset类的__init__函数:super(ADE20KDataset, self).__init__(**
    • 转到mmseg/datasets/custom.pyADE20KDatase类继承的CustomDataset
      • 调用loading.py中LoadAnnotations类进行初始化,获得image和mask的地址等信息,并获取image和mask名字的dict:self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,self.ann_dir,self.seg_map_suffix, self.split)
      • 实例对象做运算时,就会调用CustomDataset类中的__getitem__()__:self.prepare_train_img(idx)
      • 调用prepare_train_img函数:self.pipeline(results),调用mmseg/datasets/pipelines/loading.pyLoadImageFromFile类和其他数据增强
  • 创建 validation dataset: datasets.append(build_dataset(val_dataset))
  • model, data, config 喂给训练函数: train_segmentor(model, datasets, cfg)

Level 2: 转进到 mmseg.apis 模块的 train_segmentor 函数内:

  • 创建 dataloader: data_loaders = [build]()_dataloader(dataset, config)]
  • model 搬到 GPU 上去: model = MMDataParallel(model.cuda(), cfg)
  • 创建 optimizer: optimizer = build_optimizer(model, cfg)
  • 创建 runner: runner = build_runner(model, cfg, optimizer)
  • 给 runner 注册 training hooks: runner.register_training_hooks(cfg)
  • 给 runner 注册 validation hooks: runner.register_hook(eval_hook(val_dataloader, eval_cfg))
    • 这个 eval_hook 是 EvalHook 类实例, 其重写了 after_train_iterafter_train_epoch 两个方法, 在 IterBasedRunner 中用的是 after_train_iter
  • 开始训练 runner.run(data_loaders, cfg.workflow)

Level 3: 转进到 mmcv/runner/iter_based_runner.py 内的 IterBasedRunner 类的 run 函数内部:

  • Training 模式, mode = 'train', i = 0, 运行 iter_runner(iter_loaders[i](), **kwargs)
    • 实质上是在运行 IterBasedRunner类的 train 函数: train(iter_loaders[0](), **kwargs)
    • while self.iter < self._max_iters: 可以看到, 这个 train 函数一共会被调用 self._max_iters
    • 从中也可以看到这个 train 函数其实只负责做一个 batch 数据的 forward 计算
  • Validation 模式, 此处其实没有运行
    • mmseg 的所有 setting 都是 workflow = [('train', 1)]
    • 实际上的 validation 是通过在 after_train_epoch 节点调用 EvalHook 对象的 after_train_iter方法实现的。

Level 4: 转进到 IterBasedRunner 类的 train 函数内部

  • 读取一个 batch 的数据: data_batch = next(data_loader)
  • 调用 model 的 train_step 函数计算 loss: outputs = self.model.train_step(data_batch)
  • 尝试选择性进行 validationself.call_hook('after_train_iter')
    • 实质上是调用 EvalHook 类实例的 after_train_iter 函数;

Level 5: 转进到 EvalHook 类实例的 after_train_iter 函数内部:

  • 如果当前迭代数不能够被 interval 整除, 就不做 validation: if not self.every_n_iters(runner, self.interval): return
  • 如果能被整除, 计算一下 validation set 上的结果: results = single_gpu_test(model, dataloader)
    • 这一步就是 enumerate 一下 data_loader, 对于每个 batch 都用 model forward 一下, 把 result 都 append 起来得到一个 list results, 就不再展开了
  • 对于分割结果再调用 datasetevaluate 函数计算一下 mIoU, mDice, mFscoremetric 数值
    • 其实就是通过调用下 mmseg.core 里面的 eval_metrics 函数调用 total_intersect_and_union 函数计算下上述数值

4.函数说明:

  • self.pipeline = Compose(pipeline)

    • Compose:把函数组合起来,每个函数的返回值是下一个函数的参数
  • print_log(f’Loaded {len(img_infos)} images’, logger=get_root_logger())

    • print_log:打印日志
  • target = torch.where(target == ignore_index, target.new_tensor(0), target)

    • torch.where:查找 target 中值为ignore_index(255)的值转为0,
    • new_tensor:target.new_tensor是将target的值copy一份,不共享内存,new_tensor(0)指值为0同样size矩阵

5.疑问解答

  • CustomDataset类中pre_eval函数的ignore_index=255是起什么作用的? 是不计算255的loss吗
    • mmseg/core/evaluation/metrics.py函数中找到了答案
    • intersect_and_union函数中计算IOU的时候,将ignore_index=255的值忽略掉:mask = (label != ignore_index),相当于不计算背景的准确率,获取到的相当于是召回率Recall
    • 需要注意的是,其中reduce_zero_label=True时,是将像素值为0的转为255:label[label == 0] = 255,会在mask = (label != ignore_index)处一并忽略
  • 注意:255值在标注员工标注过程中代表不需要标注的区域,相当于背景,在需要标注的区域,背景值是0

参考链接:

【1】MMSegmentation 训练测试全流程及其关键节点

你可能感兴趣的:(目标检测&实例分割,深度学习,python,mmsegmentation,pytorch,图像分割)