最近在用商汤的MMdetection3d来跑模型,有集成代码给自己用是挺好的,但是也正是因为集成代码,改起来比较麻烦。在MMdetection3d中,默认训练一个轮次,验证一个轮次,修改工作流也不行,即workflow = [('train', 1)]
改为[('train', 1), ('val', 1)]
。这里官方也给出解释了,[('train', 1), ('val', 1)]
表示的是在训练一个轮次后,验证一次,计算验证集的损失和精度
,而[('train', 1)]
则是训练一个轮次,验证一次,给出精度但是不会计算损失
。所以,如果想要自己定义运行的方式的话,比如,我觉得验证时间太久了,想改成训练5个轮次再验证一次,特别是一些transformer模型,推理起来很慢。就只能重新写一个hook
,但是我觉得应该不用,只是改了验证的间隔而已,官方的代码应该还不至于这么写得这么low。于是我用断点慢慢查看流程,终于发现了入口。
首先,在mmdet3d/apis/train.py
中出现了注册验证模型的相关代码:
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = MMDET_DistEvalHook if distributed else MMDET_EvalHook
# In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
# priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
就是在这里把验证模型的相关参数传入到下一个函数/类别中,于是点进去eval_hooks.py
(入口:eval_hook = MMDET_DistEvalHook if distributed else MMDET_EvalHook
),可以看到这个是官方默认的验证参数设置的代码,然后这个类是基于BaseEvalHook
写的,也就是说它有BaseEvalHook
的一些参数,所以点进去BaseEvalHook
看,来到evaluation.py
,发现里面有个参数‘interval’
,参数解释为:interval (int): Evaluation interval. Default: 1.
再看一下类的定义有哪些参数:
def __init__(self,
dataloader,
start=None,
interval=1,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
out_dir=None,
file_client_args=None,
**eval_kwargs):
可以看到里面的信息,interval=1
,也就是这里控制了训练多少次后再验证模型。
得到入口之后,下一步就是把参数加进去,改变类的默认的值。要修改的地方有两处,第一个是mmdetection3d/configs/
下的参数文件,在最后加一句:
eval_interval = 2 # 间隔自己改
就比如,我那个的是groupfree3d的模型,就在mmdetection3d/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O512.py
里面加入这句代码。
然后再来到mmdet3d/apis/train.py
中,按如下修改:
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
# 加这一句话
eval_cfg['interval'] = cfg.get('eval_interval')
大概是301行左右,修改完之后就可以了。
运行效果如下:
2022-05-18 16:37:37,895 - mmdet - INFO - Epoch [1][600/601]
....
2022-05-18 16:38:13,879 - mmdet - INFO - Epoch [2][50/601]
可以看到,训练完之后并不会验证一次,而是等训练个轮次之后再验证。
对于MMdetection3d这个集成的代码,我保持一种怀疑的态度,毕竟人家论文里面要训练400个轮次的东西,在你这里就成了80个轮次了,可能是我太菜了的原因吧,没有理解太多MMdetection3d代码的实现。
本人水平有限,如果有不妥之处,敬请指出,另外,有些评论问题我实在不知道该怎么回答的,请见谅。