Knowledge Distillation论文阅读之(1):Structured Knowledge Distillation for Semantic Segmentation

文章目录

  • abstract
  • Introduction
  • Related Work
    • semantic segmentation(语义分割方面)
    • Knowledge distillation(知识蒸馏方面)
    • Adversarial learning(对抗学习方面)
  • Approach
    • 3.1 Structured Knowledge Distillation(结构化的知识蒸馏)
    • 3.2 Optimization(优化)

abstract

  • 这篇文章了使用 “知识蒸馏(KD)” 的方法利用大规模的语义分割网络来实现小规模的语义分割网络
  • 我们非常直接地利用图像分类的方式对图片的每个像素进行 “知识蒸馏”
  • 我们进一步地从大型分割网络中提取结构化的信息到小的分割网络中;因为语义分割是一个结构化的预测问题
  • 本文提出了两种结构化的蒸馏提取方案:
    • pair-wise(成对) 的蒸馏方案:在成对的 student 网络的图片和 teacher网络的图片中提炼出成对的相似性(similarity)
    • holistic(整体的)蒸馏方案:利用 GAN 网络去提取整体的信息
  • 在三个场景处理数据集:Cityscapes, Camvid and ADE20K 上证明了本文的先进性

Introduction

  • 语义分割解决的问题是对输入图片的每一个像素点进行标签预测(label prediction);是计算机视觉的一个重要工作;在生活中有很多的应用:自动驾驶、视频监视、虚拟现实。。。
  • 深度神经网络对语义分割有很多解决方法,比如全卷积神经网络(fully-convolutional neural networks(FCN);还包括随后的很多研究,例如:DeepLab,PSPNet,OCNet,RefineNet 和 DenseASPP 都取得了很好的效果和很高的精度,但是其代价就是庞大的模型和计算量
  • 目前,轻便的神经网络模型、轻量级的运算和高精度的分割已经成为时代的要求,因为要搭载到手机等终端设备上,所以这样的要求是大势所趋;目前尝试的方向都集中在设计用于语义分割的轻量化网络,或者借鉴分类网络的设计,例如:ENet、ESPNet 、 ERFNet、ICNet;这些论文的研究重点是通过“复杂型”网络的帮助来训练紧凑型(袖珍型)的轻便语义分割网络。
  • 我们研究了用于训练 “紧凑” 语义分割网络的知识蒸馏策略,该策略在分类任务中被证明是有效的。
  • **pixel-wise 方案:**作为一种直观的方案,我们将语义分割问题简单地看作是许多单独的像素分类问题,然后直接将知识蒸馏方案应用到像素级;这个简单的方案,我们称之为 “像素级提取(pixel-wise distillation)方案” ,它将复杂网络 (Teacher) 产生的相应像素的类概率 (class probability) 转移到紧凑网络(Student) 中。
  • pair-wise 方案: 考虑到语义分割是一个结构化的预测问题,我们提出了 “结构化知识蒸馏(structured knowledge distillation)”和“结构信息传递(transfer the structured information)” 两个步骤的组合操作;这种成对的精算方案是受广泛研究的成对马尔可夫随机场框架的启发,以增强空间标记的连续性,其目标是对从紧凑网络和繁琐网络计算的像素点之间的成对相似性进行对齐(align)。
  • 整体精馏方案 (holistic distillation): 的目标是对齐(align)紧凑的分割网络和复杂的分割网络生成的分割图(segmentation maps)之间的高阶一致性,而这是前面提到的 pixel-wise 和 pair-wise 蒸馏方案所不具备的
  • 为了达到这个目的,我们优化了目标函数(损失函数),结合了传统的多分类交叉熵损失和知识蒸馏系统的损失;本文的主要贡献如下:
    • 研究了知识蒸馏策略用来训练紧凑分割网络
    • 我们提出了两个结构化的知识蒸馏方案,pair-wise(成对蒸馏)方案和 holistic(整体蒸馏)方案,增强紧凑网络和复杂网络输出之间的“两两对应(pair-wise)”特性和两个网络之间高阶特征(higher-order consistency)的一致性。
    • 我们证明了我们研究的有效性和先进性,我们与最近新提出的具有最好效果的ESPNet,MobileNetV2-Plus 和 ResNet18 在三个benchmark datasets:Cityscapes, CamVid and ADE20K上进行验证。

Related Work

semantic segmentation(语义分割方面)

  • 深度卷积神经网络是主流的语义分割方法 FCN,DeConvNet,U-Net
  • 为了提高网络的工作能力和相应的分割性能,已经提出了各种方案,例如:更强的主干网络:GoogleNets,ResNets 和 DenseNets,这些网络表现出了更好的分割结果;通过空洞卷积(dilated convolutions)的方式或者多路径优化网络(multi-path refine networks)也做出了很好的分割效果;利用多尺度的上下文(multi-scale context)例如在 PSPNet 中利用空洞卷积(dilated convolutions)、金字塔池化模块(pyramid pooling modules),在 DeepLab 中使用空间金字塔池化模块(astrous spatial pyramid pooling),目标对象上下文(object context)也对语义分割任务有很好的帮助。
  • Lin等人将深度模型(deep models)与结构化输出(structured output learning)学习相结合,实现语义分割的任务
  • 对于一些复杂的神经网络能够实现高精度的语义分割任务,高效的分割网络迅速吸引了人们的大量关注并在实际中大量应用例如:手机app;
  • 利用因子分解技术(factorization)对卷积运算进行加速,从而实现网络的轻量化设计是很多工作的重心。
  • 受 Rethinking the inception architecture for computer vision 启发,ENet[31]集成了多个加速的方法,包括:多分支模块(multi-branch modules)、早期特征图分辨率下采样(early feature map resolution down-sampling)、小尺寸的解码器(small decoder size)、滤波器张量因子分解(filter tensor factorization)等
  • SQ采用压缩fire modules 和并行的空洞卷积来实现有效的分割。
  • ESPNet 提出了一种有效率的空间金字塔模型,这个模型基于滤波器因式分解(filter factorization)技术:使用逐点卷积(point-wise convolutions)和空洞卷积的空间金字塔(spatial pyramid of dilated convolutions)的方式来代替了普通的卷积操作。
  • 高效的分类网络例如:MobileNet, ShuffleNet, and IGCNet 都是用来加速语义分割任务而被设计出来的。此外,ICNet (image cascade network) 利用了处理低分辨率图像的效率和对高分辨率图像的高推理质量,实现了效率和精度的平衡。

Knowledge distillation(知识蒸馏方面)

  • 知识蒸馏是一种把信息和知识从复杂网络迁移到紧凑网络来提高紧凑网络表现的方法
  • 他在图像分类的应用是:把复杂网络输出的概率分布作为软标签(soft target)来训练紧凑网络,或者用来转换中间状态的特征图(intermediate feature maps)
  • 当然还有其他的一些应用:包括目标识别,行人识别等
  • 最近独立开发的语义分割应用与我们的方法有关;它主要是分别提取每个像素的类概率(像我们的逐像素蒸馏(pixel-wise distillation))和每个局部patch的标签的中心-周围差异(center-surrounding differences)
  • 相比之下,我们侧重于提取结构化知识:成对蒸馏(pair-wise distillation),它传递所有像素对之间的关系,而不是局部 patch中的关系;我们还做了整体蒸馏(holistic distillation)的工作,它负责传递捕获高阶信息的整体知识。

Adversarial learning(对抗学习方面)

  • GAN 网络在文本生成、图像生成领域有非常广泛的研究,conditional version 被成功的应用到 image-to-image 的转换中,包括:风格的转换、图像修复、图像着色以及文本到图片的转化中
  • 对抗学习的思想(adversarial learning)也同样被应用到姿态估计(pose estimation)中,使人体姿态估计的结果与 ground-truth 图没有区别。
  • GAN 网络的一个挑战是生成器(generator)的连续输出(continuous output)和离散的 labels 之间的不匹配,使GAN的鉴别器的成功很受局限
  • 但是在我们的文章中,GAN 不存在这个问题,因为我们使用的是 teacher 网络的softmax 前一层的输出(logits,这是个连续的概率分布),我们使用对抗学习的方式使得网络能够在复杂网络产生的语义分割图和紧凑网络产生的语义分割图之间对齐(alignment)。

Approach

  • 语义分割的任务是把每一个像素进行分类(预测每个像素的标签类别),假设现在一共有 C C C 类标签;输入网络的 RGB 图像 I I I 的尺寸是 W × H × 3 W×H×3 W×H×3;然后计算出的特征图(feature map) 用 F F F 来表示,它的尺寸是 W ′ × H ′ × N W^{'}×H^{'}×N W×H×N N N N是通道数;最后,一个分类器被用来从特征图 F F F 中计算出分割图(segmentation map) Q Q Q,分割图 Q Q Q 的尺寸为 W ′ × H ′ × C W^{'}×H^{'}×C W×H×C,然后将其进行上采样,把尺寸还原到 W × H W×H W×H 跟输入图的尺寸一样

3.1 Structured Knowledge Distillation(结构化的知识蒸馏)

  • 我们采用知识蒸馏策略将笨拙分割网络 T T T 的知识转移到紧凑分割网络 S S S 中,以更好地训练紧凑分割网络;为了将结构化的知识从笨重的网络中转移到紧凑网络中,我们提出了两种结构化的知识蒸馏方案,即成对的知识蒸馏(pair-wise distillation)和整体的知识蒸馏(holistic distillation);其结构如图2所示。
    Knowledge Distillation论文阅读之(1):Structured Knowledge Distillation for Semantic Segmentation_第1张图片
    figure2:(a)成对的知识蒸馏策略(pair-wise),(b)像素级的知识蒸馏(pixel-wise),(c)整体蒸馏策略;训练过程中,我们固定复杂的网络作为我们的teacher network,只优化 student net 和 discriminator鉴别器;具有紧凑结构的 student 网络将用三个蒸馏损失(distillation terms)和一个交叉熵项(cross_entropy)进行训练。

Pixel-wise distillation

  • 我们将分割问题视为一个单独的像素标记问题的集合,并直接使用知识蒸馏来对齐从紧凑网络产生的每个像素的类标记的概率。

  • 我们采用了一种显著的方法:使用从复杂网络模型中产生的软标签(soft target),用来训练紧凑网络,我们定义的损失函数如下:

    l p i ( S ) = 1 W ′ × H ′ ∑ i ∈ R K L ( q i s ∣ ∣ q i t ) l_{pi}(S)={\frac{1}{W^{'}×H^{'}}}{\sum_{i∈R}KL(q_i^{s}||q_i^{t})} lpi(S)=W×H1iRKL(qisqit) ( 1 ) (1) (1)

  • q i s q_i^{s} qis 表示紧凑网络 S S S 产生的第 i i i 个像素的类概率; q i t q_i^{t} qit 表示复杂网络 T T T 产生的第 i i i 个像素的类概率; K L ( ⋅ ) KL(·) KL() 代表的是 Kullback-Leibler在两个概率之间的散度即:信息熵差值; R = { 1 , 2 , ⋅ ⋅ ⋅ ⋅ , W ′ × H ′ } R=\{{1,2,····,{W^{'}×H^{'}}}\} R={1,2,,W×H}
    在这里插入图片描述

这个公式表示的是,对于 Teacher 网络中产生的分割图以及 Student 产生的分割图,对于他们的每个对应的像素点,求他们之间的相对熵(KL)然后求和代表的整个语义分割图中他们所有的信息熵的差距;至于前面的系数 1 W ′ × H ′ {\frac{1}{W^{'}×H^{'}}} W×H1 则是一个取平均的过程。

Pair-wise distillation

  • 受马尔可夫随机场成对框架的启发(该框架被广泛用于改善空间标记的连续性),我们提出了转移成对像素信息(pair-wise)的方法,特别是像素之间成对的相似性(pair-wise similarities among pixels)。

  • a i j t a_{ij}^t aijt 代表从复杂网络 T T T 中产生的第 i i i 个像素点,和第 j j j 个像素点的相似度(similarity), a i j s a_{ij}^s aijs 代表从复杂网络 S S S 中产生的第 i i i 个像素点,和第 j j j 个像素点的相似度(similarity);我们采用平方差分来表示成对的相似性扩散损失

    l p a ( S ) = 1 ( W ′ × H ′ ) 2 ∑ i ∈ R ∑ j ∈ R ( a i j s − a i j t ) 2 l_{pa}(S)={\frac{1}{(W^{'}×H^{'})^2}}{\sum_{i∈R}}{\sum_{j∈R}(a_{ij}^{s}-a_{ij}^{t})}^2 lpa(S)=(W×H)21iRjR(aijsaijt)2 ( 2 ) (2) (2)

    这个公式在 teacher 网络和 student 网络内部先计算了各自像素点之间的相似度;然后把两个图的相似度做差,反映出来的是两个图内部结构是否很相似,如果 student 网络学习的很好那么它分割出来的图像应该和 teacher 做出来的差不多,即图像内部的 pixel 之间的相似性会和 teacher 网络内部 pixel 之间的相似性相近;这种情况下, ( a i j s − a i j t ) {(a_{ij}^{s}-a_{ij}^{t})} (aijsaijt) 这个差值就会很小,相反,就会比较大。

  • a i j = f i ⊤ f j ∣ ∣ f i ∣ ∣ 2 ∣ ∣ f j ∣ ∣ 2 a_{ij}={\frac{f_i^⊤f_j}{||f_i||_2||f_j||_2}} aij=fi2fj2fifj 这里使用的相似度计算方式是 cos 相似度。

  • 我个人理解的公式 + 图放在这里:

Holistic distillation

  • 我们对齐复杂网络和紧凑网络分割图像之间的高阶关系(higher-order relations);以分割图的整体映射(holistic embedding) 作为表示。

  • 我们使用对抗学习策略来解决整体的知识蒸馏问题

  • 紧凑网络被看做是生成器 generator,它的输入时 RGB 图像 I I I,通过 I I I 预测最后的语义分割图 Q s Q^s Qs 作为 fake sample(假样本);我们希望 Q s Q^s Qs Q t Q^t Qt 相似度越高越好( Q t Q^t Qt 是 teacher 通过 I I I 预测的语义分割图,被作为是 real sample(真样本))

  • 在这里我们使用了 Wasserstein 距离来评估 real 分布和 fake 分布;因为之前提到过了,无论是 teacher 还是 student 最后的输出都是一个概率分布。

    l h o ( S , D ) = E Q s ∼ p s ( Q s ) [ D ( Q s ∣ I ) ] − E Q t ∼ p t ( Q t ) [ D ( Q t ∣ I ) ] l_{ho}(S,D)={\mathbb E_{Q^s\sim p_s(Q^s)}[D(Q^s|I)]-\mathbb E_{Q^t\sim p_t(Q^t)}[D(Q^t|I)]} lho(S,D)=EQsps(Qs)[D(QsI)]EQtpt(Qt)[D(QtI)] ( 3 ) (3) (3)

  • E {\mathbb E} E 是期望算子

  • D ( ⋅ ) D(·) D() 是个映射 (embedding) 网络,在 GAN 中扮演了 discriminator 的角色

  • Q Q Q I I I 一同输入 D D D 中获得一个整体的评分值

  • 梯度惩罚(gradient penalty)满足 Lipschitz 条件

  • 语义分割图和 RGB 的原输入一起作为总的输入送进 embedding 网络 D D D,这是一个全连接神经网络(fully convolutional neural network),有五个卷积层;两个自注意模块(self-attention module)被嵌入最后的三个层去捕获结构信息(structured information);这样的一个鉴别器 discriminator 可以产生整体映射(holistic embedding)

3.2 Optimization(优化)

  • 整个目标函数有:
    • 分类的交叉熵损失(multi-class cross_entropy) l m c ( S ) l_{mc}(S) lmc(S)
    • 用来做 pixel-wise 的 l p i ( S ) l_{pi}(S) lpi(S)
    • 用来做结构化蒸馏的损失(包括用来做 pair-wise 蒸馏的损失 l p a ( S ) l_{pa}(S) lpa(S)、用来做 holistic 蒸馏的损失 l h o ( S , D ) l_{ho}(S,D) lho(S,D)

l ( S , D ) = l m c ( S ) + λ 1 ( l p i ( S ) + l p a ( S ) ) − λ 2 ( l h o ( S , D ) ) l(S,D)=l_{mc}(S)+\lambda_1(l_{pi}(S)+l_{pa}(S) )-\lambda_2(l_{ho}(S,D)) l(S,D)=lmc(S)+λ1(lpi(S)+lpa(S))λ2(lho(S,D))

  • λ 1 \lambda_1 λ1 λ 2 \lambda_2 λ2 分别设置为 10 和 0.1,使得损失的范围在同一个可比较的量级;我们通过最小化目标函数来优化和提升语义分割网络 S S S,这要求我们最大化鉴别器 D D D 的损失;可以通过迭代下面两个步骤来实现

    • Train the discriminator(训练鉴别器) D 训练鉴别器相当于最小化 l h o ( S , D ) l_{ho}(S,D) lho(S,D) D D D 的目标是对来自 teacher 网络的 real sample(真实样本)给出高的分数,对来自student 网络的 fake sample (假样本)给出低的分数。

    • Train the compact segmentation network S (训练紧凑的语义分割网络 S) 给定义一个鉴别器网络 (discriminator network),目标是最小化紧凑语义分割网络的多分类交叉熵(multi-class cross-entropy)损失以及知识蒸馏损失(distillation loss)

      l m c ( S ) + λ 1 ( l p i ( S ) + l p a ( S ) ) − λ 2 l h o s ( S ) l_{mc}(S)+\lambda_1(l_{pi}(S)+l_{pa}(S))-\lambda_2l_{ho}^s(S) lmc(S)+λ1(lpi(S)+lpa(S))λ2lhos(S)

      • 这里的 l h o s ( S ) = E Q s ∼ p s ( Q s ) [ D ( Q s ∣ I ) ] l_{ho}^s(S)=\mathbb E_{Q^s\sim p_s(Q^s)}[D(Q^s|I)] lhos(S)=EQsps(Qs)[D(QsI)],是公式(3)中的一部分,我们期望 S S S 可以在鉴别器 D D D 的评估下达到一个较高的分数

你可能感兴趣的:(Knowledge,Distillation,类别论文阅读)