本文作者:CPFLAME | 编辑:Amusi
原文链接:https://zhuanlan.zhihu.com/p/323814368
本文已由原作者授权,不得擅自二次转载
笔者重构了一版centernet(objects as points)的代码,并加入了蒸馏,多模型蒸馏,转caffe,转onnx,转tensorRT,把后处理也做到了网络前向当中,对落地非常的友好。
放一个centerX多模型蒸馏出来的效果图,在蒸馏时没有用到数据集的标签,只用了两个teacher的model蒸馏同一个student网络。就用大家的老婆来做demo吧。
不感兴趣的童鞋可以收藏一下笔者的表情包,如果觉得表情包好玩,跪求去github点赞。
代码地址:
https://github.com/CPFLAME/centerX
centernet是我最喜欢的检测文章之一,没有anchor,没有nms,结构简单,可拓展性强,最主要的是:落地极其方便,选一个简单的backbone,可以没有bug的转成你想要的模型(caffe,onnx,tensorRT)。并且后处理也极其的方便。
但是centernet原版的代码我初看时有点吃力,但也没有重构的想法,过了一些时日后我找到了centernet-better和centernet-better-plus,于是把他们的代码白嫖了过来然后自己完善一下,形成对我友好的代码风格。(当然剽窃最多的其实是fast reID和detectron2)
由于本人不喜欢写纯技术方面的博客,也不想写成一篇纯PR稿(从本科开始就深恶痛觉写实验报告),更不想让人觉得读这篇文章是在学习,所以本篇文章不太正经,也没有捧一踩一的操作,跟别人的宣传稿不太一样。
毕竟代码写的不是打打杀杀,而是人情世故,真学东西还得看其他人的文章,看我的也就图一乐。
一般来说读文章的人点进来都会带着这样一个心理,我为什么要用centerX,明明我用别的框架用的很顺利了,转过来多麻烦你知道吗,你在教我做事?
https://github.com/CPFLAME/centerX
受到老领导道家思维编程的启发,centerX的trick里面也贯彻了一些具有中国特色社会主义的中心主题思想。
这个方面没有什么好说的,也没有做到和其他框架的差异化,只是在detectron2上对基础的centernet进行了复现而已,而且大部分代码都是白嫖自centernet-better和centernet-better-plus,就直接上在COCO上的实验结果吧。
centerX_KD是用27.9的resnet18作为学生网络,33.2的resnet50作为老师网络蒸馏得到的结果,详细过程在在下面的章节会讲。
大嘎好,我是detection。我时常羡慕的看着隔壁村的classification,embedding等玩伴,他们在蒸馏上面都混得风生水起,什么logits蒸馏,什么KL散度,什么Overhaul of Feature Distillation。每天都有不同的家庭教师来指导他们,凭什么我detection的教育资源就很少,我detection什么时候才能站起来!
造成上述的原因主要是因为detection的范式比较复杂,并不像隔壁村的classification embedding等任务,开局一张图,输出一个vector:
我编不下去了
我们再来回头看看centernet的范式,哦,我的上帝,多么简单明了的范式:
网络输出三个头,一个预测中心点,一个预测宽高,一个预测中心点的偏移量
没有复杂的正负样本采样,只有物体的中心点是正样本,其他都是负样本
这让笔者看到了在detection上安排家庭教师的希望,于是我们仿照了centernet本来的loss的写法,仿照了一个蒸馏的loss。具体的实现可以去code里面看,这里就说一下简单的思想。
对于输出中心点的head,把teacher和student输出的head feature map过一个relu层,把负数去掉,然后做一个mse的loss,就OK了。
对于输出宽高和中心点的head,按照原centernet的实现是只学习正样本,在这里笔者拍脑袋想了一个实现方式:我们用teacher输出中心点的head过了relu之后的feature作为系数,在宽高和中心点的head上所有像素点都做L1 loss后和前面的系数相乘。
在蒸馏时,三个head的蒸馏loss差异很大,需要手动调一下各自的loss weight,一般在300次迭代后各个蒸馏loss在0~3之间会比较好。
所以在之前我都是300次epoch之后直接停掉,然后根据当前loss 预估一个loss weight重新开始训练。这个愚蠢的操作在我拍了另外一次脑袋想出共产主义loss之后得以丢弃。
在模型蒸馏时我们既可以在有标签的数据上联合label的loss进行训练,也可以直接用老师网络的输出在无标签的数据集上蒸馏训练。基于这个特性我们有很多妙用
当在有标签的数据上联合label的loss进行训练时,老师训N个epoch,学生训N个epoch,然后老师教学生,并保留原本的label loss再训练N个epoch,这样学生的mAP是训出来最高的。
当在无标签的数据集上蒸馏训练时,我们就跳出了数据集的限制,先在有标签的数据集上老师训N个epoch,然后老师在无标签的数据集上蒸馏学生模型训练N个epoch,可以使得学生模型的精度比baseline要高,并且泛化性能更好。
之前在centernet的source code上还跑过一个实验,相同的网络,自己蒸馏自己也是可以涨点的。在centerX上我忘记加进去了。
我们拉到实验的部分,上述的瞎比猜想得到验证。
看到蒸馏效果还可以,可以在不增加计算量的情况下无痛涨点,笔者高兴了好一阵子,直到笔者在实际项目场景上遇到了一个尴尬地问题:
因为数据集A里面可能会有大量的未标注的B,B里面也会有大量的未标注的A,直接放到一起训练肯定不行,网络会学傻。
在笔者再次拍了拍脑袋后,发挥了我最擅长的技能:白嫖。想到了这样一个方案:
笔者分别在人体和车,以及人体和人脸上做了实验。数据集为coco_car,crowd_human,widerface.
笔者在训练centerX时,出现过这样一个问题,设置合适的lr时,训练的一切都那么自然又和谐,而当我lr设置大了以后,有时候会训到一半,网络直接loss飞涨然后mAP归零又重新开始往上爬,导致最后模型的mAP很拉胯。对于这种情况脾气暴躁的我直接爆了句粗口。
骂完了爽归爽,问题还是要解决的,为了解决这个问题,笔者首先想到笔者的代码是不是哪里有bug,但是找了半天都没找到,笔者还尝试了如下的方式:
看来这个bug油盐不进,软硬不吃。训练期间总会出现某个时间段loss突然增大,然后网络全部从头开始训练的情况。
这让我想到了内卷加速,资本主义泡沫破裂,经济大危机后一切推倒重来。这个时候才想起共产主义的好,毛主席真是永远滴神。
既然如此,咱们一不做二不休,直接把蛋糕给loss们分好,让共产主义无产阶级的光照耀到它们身上,笔者一气之下把loss的大小给各个兔崽子head们给规定死,具体操作如下:
接下来就是实验部分看看管不管用了,于是笔者尝试了一下之前崩溃的lr,得益于共产主义的好处,换了几个数据集跑实验都没有出现mAP拉胯的情况了,期间有几次出现了loss飞涨的情况,但是在共产主义loss强大的调控能力之下迅速恢复到正常状态,看来社会主义确实优越。同时笔者也尝试了用合适的lr,跑baseline和共产主义loss的实验,发现两者在±0.3的mAP左右,影响不大。
笔者又为此高兴了好一段时间,并且发现了共产主义loss可以用在蒸馏当中,并且表现也比较稳定,在±0.2个mAP左右。这下蒸馏可以end2end训练了,再也不用人眼去看loss、算loss weight、停掉从头训了。
这个部分的代码都在code的projects/speedup中,注意网络中不能包含DCN,不然转码很难。
centerX中提供了转caffe,转onnx的代码,onnx转tensorRT只要装好环境后一行指令就可以转换了,笔者还提供了转换后不同框架的前向代码。
其中笔者还找到了centernet的tensorRT前向版本(后续笔者把它称为centerRT),在里面用cuda写了centernet的后处理(包括3X3 max pool和topK后处理)。笔者在转完了tensorRT之后想直接把centerRT白嫖过来,结果发现还是有些麻烦,centerRT有点像是为了centernet原始实现定制化去写的。这就有了以下的问题
这次笔者拍碎了脑袋都没想到怎么白嫖,于是在献祭了几根珍贵的头发之后,强行发动了甩锅技能,把后处理操作都扔给神经网络,具体操作如下:
def centerX_forward(self, x):
x = self.normalizer(x / 255.)
y = self._forward(x)
fmap_max = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)(y['cls'])
keep = (y['cls'] - fmap_max).float() + 1e-9
keep = nn.ReLU()(keep)
keep = keep * 1e9
result = y['cls'] * keep
ret = [result,y['reg'],y['wh']] ## change dict to list
return ret
onnx中可视化如下:
值得注意的是上述骚操作在转caffe的时候会报错,所以不能加。如果非要添加上去,得在caffe的prototxt中自行添加scale层,elementwise层,relu层,这个笔者没有实现,大家感兴趣可以自行添加。
考虑到大家需要向上管理,笔者写几个可以涨点的东西
除了以上的在精度方面的优化之外,其实笔者还想到很多可以做的东西,咱们不在精度这个地方跟别人卷,因为卷不过别人,检测这个领域真是神仙打架,打不过打不过。我们想着把蛋糕做大,大家一起有肉吃
其实有太多的东西想加到centerX里面去了,里面有很多很好玩的以及非常具有实用价值的东西都可以去做,但是个人精力有限,而且刚开始做centerX完全是基于兴趣爱好去做的,本人也只是渣硕,无法full time扑到这个东西上面去,所以上述的优化方向看看在我有生之年能不能做出来,做不出来给大家提供一个可行性思路也是极好的。
非常感谢廖星宇,何凌霄对centerX代码,以及发展方向上的贡献,感谢郭聪,于万金,蒋煜襄,张建浩等同学对centerX加速模块的采坑指导。
再放一遍自己的github
https://github.com/CPFLAME/centerX
以及感谢如下杰出的工作
https://github.com/xingyizhou/CenterNet
https://github.com/facebookresearch/detectron2
https://github.com/FateScript/CenterNet-better
https://github.com/lbin/CenterNet-better-plus
https://github.com/JDAI-CV/fast-reid
https://github.com/daquexian/onnx-simplifier
https://github.com/CaoWGG/TensorRT-CenterNet
在CVer微信公众号后台回复:CenterX,即可下载上述项目源代码
后台回复:目标检测二十年,即可下载39页的目标检测最全综述,共计411篇参考文献。
后台回复:CVPR2020,即可下载代码开源的论文合集
后台回复:ECCV2020,即可下载代码开源的论文合集
后台回复:YOLO,即可下载YOLOv4论文和代码
扫码添加CVer助手,可申请加入CVer-目标检测 微信交流群,目前已汇集4000人!涵盖2D/3D目标检测、小目标检测、遥感目标检测等。互相交流,一起进步!
一定要备注:如目标检测+地点+学校/公司+昵称,根据格式备注,可更快被通过且邀请进群