网络蒸馏算法的一些汇总和想法。

最近在做关键点的算法,需要用到网络蒸馏的思路,但是目前大部分都是分类任务的蒸馏,偶然找到一些其他的思路,整理一下。

1.分类任务的蒸馏方法——软标签

传统的分类任务,输出就是one hot 的标签。蒸馏网络让student去学习teacher的软标签,相当于一个更真实的概率分布。这是因为分类任务有天然的概率解释,所以软标签十分合理。

2.检测任务的蒸馏方法,这里只整理分类以外的部分

1)中间层feature间接teaching+最后regression部分加上限约束

参考文章Learning Efficient Object Detection Models with Knowledge Distillation。

中间层由于feature map的尺寸不同,需要用一个1*1的conv进行过渡,会增加一定的参数。存在一定问题。文章在这里也没说清楚,个人感觉可以同时将student和teacher的feature map投影到一个低维空间做比较,参数量会比较小。

最后的regression的teacher只是作为上限约束,保证student学出来的结果优于teacher.并不会让student向teacher学习。

我用caffe实现了一下,发现没啥效果……

2)用GAN的思想,让student和teacher处在同一个域内。

这种方法不是逐个像素的算L1 loss,而是用判别器去模糊地约束,也许会有效果。

参考http://blog.itpub.net/69911376/viewspace-2649773/

网络蒸馏算法的一些汇总和想法。_第1张图片

3.分割任务的蒸馏方法

作者提出了下面三种知识蒸馏的策略:

1) pixel-wise distillation

这里分割任务其实还是与类别有关系,最后还是用softmax求最大类别,所以这一步与分类任务的蒸馏类似,唯一区别就是分割任务是逐像素的分类。

“最简单直接的策略,借鉴分类任务上的知识蒸馏算法,将每个pixel看做分类的单位,独立地进行蒸馏”

进一步,考虑语义分割是有结构的预测任务(每个pixel的结果与它的周围pixel相关),因此提出来下面的两种策略。

2)pair-wise distillation

这种策略领用了pair-wise的马尔科夫随机场框架来增强空间labelling的连续性,目标是对齐简单网络(student)和复杂网络(teacher)中学到的pair-wise的相似度。

这个我觉得很神奇,这一步可以参考文章,主要是分别计算teacher和student的feature map 中每个像素对的相关性,再把teacher 和student的相关性算loss。这里就不需要额外的参数来统一维度了。

3)holistic distillation

这个类似前面说的GAN。

这种策略考虑比pixel-wise和pari-wise更高层次的对齐,利用了对抗式训练策略,使得简单网络和复杂网络的输出没法被区分出来,这样就达到了图像级别的知识蒸馏。

文中提到之前也有人用GAN来做语义分割,目标也是生成器的结果和ground truth没法被判别器区分出来。不过存在一个问题:生成器的输出是连续的(如0-1之间的某个值),而ground truth中的值是独立的(如0或1),因此判别器性能受限。而本文中的方法却没有这个问题,因为ground truth采用的是复杂网络的logits,也是连续的,和生成器的输出可以平等地比较,这是本文一个比较巧妙的点。

参考文章Structured Knowledge Distillation for Semantic Segmentation

以及知乎解读https://zhuanlan.zhihu.com/p/59470026

网络蒸馏算法的一些汇总和想法。_第2张图片

你可能感兴趣的:(机器学习,算法基础)