MMdetection 3d修改验证模型的间隔

前言

最近在用商汤的MMdetection3d来跑模型,有集成代码给自己用是挺好的,但是也正是因为集成代码,改起来比较麻烦。在MMdetection3d中,默认训练一个轮次,验证一个轮次,修改工作流也不行,即workflow = [('train', 1)]改为[('train', 1), ('val', 1)]。这里官方也给出解释了,[('train', 1), ('val', 1)]表示的是在训练一个轮次后,验证一次,计算验证集的损失和精度,而[('train', 1)]则是训练一个轮次,验证一次,给出精度但是不会计算损失。所以,如果想要自己定义运行的方式的话,比如,我觉得验证时间太久了,想改成训练5个轮次再验证一次,特别是一些transformer模型,推理起来很慢。就只能重新写一个hook,但是我觉得应该不用,只是改了验证的间隔而已,官方的代码应该还不至于这么写得这么low。于是我用断点慢慢查看流程,终于发现了入口。
首先,在mmdet3d/apis/train.py中出现了注册验证模型的相关代码:

       eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'

        eval_hook = MMDET_DistEvalHook if distributed else MMDET_EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

就是在这里把验证模型的相关参数传入到下一个函数/类别中,于是点进去eval_hooks.py(入口:eval_hook = MMDET_DistEvalHook if distributed else MMDET_EvalHook),可以看到这个是官方默认的验证参数设置的代码,然后这个类是基于BaseEvalHook写的,也就是说它有BaseEvalHook的一些参数,所以点进去BaseEvalHook看,来到evaluation.py,发现里面有个参数‘interval’,参数解释为:interval (int): Evaluation interval. Default: 1.再看一下类的定义有哪些参数:

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=None,
                 less_keys=None,
                 out_dir=None,
                 file_client_args=None,
                 **eval_kwargs):

可以看到里面的信息,interval=1,也就是这里控制了训练多少次后再验证模型。

实现

得到入口之后,下一步就是把参数加进去,改变类的默认的值。要修改的地方有两处,第一个是mmdetection3d/configs/下的参数文件,在最后加一句:

eval_interval = 2	# 间隔自己改

就比如,我那个的是groupfree3d的模型,就在mmdetection3d/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O512.py里面加入这句代码。
然后再来到mmdet3d/apis/train.py中,按如下修改:

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
		# 加这一句话
        eval_cfg['interval'] = cfg.get('eval_interval')

大概是301行左右,修改完之后就可以了。
运行效果如下:

2022-05-18 16:37:37,895 - mmdet - INFO - Epoch [1][600/601]
....
2022-05-18 16:38:13,879 - mmdet - INFO - Epoch [2][50/601]

可以看到,训练完之后并不会验证一次,而是等训练个轮次之后再验证。

结语

对于MMdetection3d这个集成的代码,我保持一种怀疑的态度,毕竟人家论文里面要训练400个轮次的东西,在你这里就成了80个轮次了,可能是我太菜了的原因吧,没有理解太多MMdetection3d代码的实现。
本人水平有限,如果有不妥之处,敬请指出,另外,有些评论问题我实在不知道该怎么回答的,请见谅。

你可能感兴趣的:(深度学习,python,深度学习,开发语言)