知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation

Structured Knowledge Distillation for Semantic Segmentation

论文地址:https://arxiv.org/pdf/1903.04197.pdf

我发现的第一篇语义分割使用知识蒸馏方法的论文,太强了~~本文内容是参考网上论文解读和原论文而成,如有侵权,请联系删除。

语义分割问题可以理解为像素级的分类,而知识蒸馏方法对分类任务有较好的表现,可能是基于此,作者使用知识蒸馏对语义分割下手了~

算法原理简述

知识蒸馏方法本质是用一个大且复杂的模型(teacher)学习到的知识知道小而紧凑的模型(student),目标是让student的网络输出结果尽量与teacher网络的输出结果一致,这样就达到了用小代价获取大性能的目的,岂不美哉。

那teacher network该怎么指导student network进行学习呢?换句话说,该怎么做,才能使student 学习结果向 teacher 学习结果看齐?
我们先设想一下,如果我们是老师教学生,那我们可以给出最后的正确答案给学生,让学生朝着这个目标学习,当然,也可以指导调节学生学习过程中的一些学习细节,目的也是为了学生学习的结果最后能向老师看齐(不知道这个比喻怎么样,,,)。

这里,就涉及到了损失函数,损失函数就是teacher对student的一个指导,标杆,作用是让学生的学习结果越来越逼近teacher的学习结果。

该论文中作者一方面将分割问题理解为像素分类问题,所以自然地使用衡量分类差异的逐像素(Pixel-wise)的损失函数Cross entropy loss,这是在最终的输出结果Score map中计算的。这一损失可以理解为上述teacher给学生的最后正确答案。
另一方面,作者引入了图像的结构化信息损失(可以先理解为teacher对student学习过程中一些细节的指导,也不知道这比喻对不对,先这样吧~~~),
如下图
知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation_第1张图片

图中作者给出了两种结构化信息的各自衡量方式,
第一种结构化信息,作者认为在语义分割中,预测结果具有自相似性,作者衡量这种结构化信息的方式是Teacher预测的两像素结果和Student网络预测的两像素结果一致。衡量这种损失,作者称之为Pair-wise loss(也许可以翻译为“逐成对像素”损失)。

第二种结构化信息是对图像整体结构相似性的度量,作者引入对抗网络(不是很了解)的思想,设计专门的网络分支对Teacher网络和Student网络预测的结果进行分类,如果两个结果都分为同一类,说明网络收敛OK,这种损失称为Holistic loss(整体损失)。

综上,作者使用了三种损失来进行知识蒸馏,分别是逐像素的损失(Pixel-wise loss,PI)、逐像素对的损失(Pair-wise loss,PA)、整体损失(Holistic loss,HO)。

APPROACH

Pixel-wise distillation

像素级蒸馏损失函数如下:
在这里插入图片描述
知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation_第2张图片
这个蒸馏模块就是正常蒸馏的思路,我们把教师网络输出的概率拿过来,与学生网络输出的概率做loss,让学生网络逼近教师网络,在图中直观的看就是两个分割图做loss,但其实是概率做loss。

Pair-wise loss,PA

像素对蒸馏损失函数如下:
在这里插入图片描述
知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation_第3张图片
首先教师网络是不进行优化的,其是已经训好的且好使的分割网络。按先后顺序来说我们先讲Pair-wise distillation,这一个部分作者是受马尔科夫随机场(条件随机场)的启发,作者想要找两两像素之间的相关性以提高网络的效果。At表示教师网络特征图第i个像素与第j个像素之间的相关性,As表示学生网络ij之间相关性,通过下式(平方差)来计算蒸馏loss,让学生网络逼近教师网络。

Holistic loss,HO

整体蒸馏损失函数如下:
在这里插入图片描述
在这里插入图片描述
这里作者利用了GAN的思想,学生网络被看作是为生成器,其输入就是数据集中的RGB图像,输出的是分割图(fake),而教师网络输出的分割图为真实图,分割图与RGB图送入Discriminator net做一个嵌入,这个网络相当于GAN中的判别器会输出一个得分,得分表示RGB与分割图是否匹配,真实分布与错误分布之间计算Wasserstein距离。

整个目标函数由传统的多类交叉熵损失 mc (S) 组成,, 且具有像素化和结构化蒸馏项,一共有4个损失,多类交叉熵损失即是学生网络输出与真实label做普通的loss。最终的loss如下:
在这里插入图片描述
知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation_第4张图片
知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation_第5张图片

实验结果

知识蒸馏学习笔记--Structured Knowledge Distillation for Semantic Segmentation_第6张图片
由上图可以看出,使用了知识蒸馏后精度有了明显的提升,结构化信息的蒸馏可以让student network学的更好。

有点粗糙,后面回来补充~

你可能感兴趣的:(模型压缩与加速,深度学习,计算机视觉,深度学习,知识蒸馏,模型压缩加速)