Learning without Forgetting 详解(LwF)

一、概述

这篇文章仍然从最简单的分类任务入手,巧妙运用了Knowledge Distill技术来缓解这样的问题。作者的oral。具体地,使用旧模型作为teacher model,对于新任务中的每一个样本。和传统的Finetune比起来,LwF使用了teacher model输出的soften softmax对新任务中的样本进行约束。
Learning without Forgetting 详解(LwF)_第1张图片

二、网络结构详解

对于一个CNN网络,对于shared参数 θ s θ_s θs,和用于做特殊任务的参数 θ o θ_o θo。我们的目的就是增加一个新任务的参数 θ n θ_n θn,并让他学习到的参数能够在新的和旧的任务上都有很好的表现(见图e)。
算法流程大致如下:
Learning without Forgetting 详解(LwF)_第2张图片
训练过程记录如下:
1)记录新的数据在原始的网络上的outputs(defined by θ s θ_s θsand θ o θ_o θo),对于新增的类,我们增加相应的FC的节点个数,兵随机初始化权重 θ n θ_n θn
2)我们训练网络并优化其loss在所有的分类上有最小的loss。在训练的时候,我们首先freeze掉 θ s θ_s θs θ o θ_o θo,然后训练 θ n θ_n θn指导其收敛,然后我们在训练所有的 θ s θ_s θs θ o θ_o θo θ n θ_n θn指导其收敛。

三、损失函数

在这里插入图片描述
ps:
y_hat: the softmax output of the network;
yn: the one-hot ground truth label vector.
在这个新的网络中,我们希望对于原来的任务其输出能和原来的网络的输出接近,所以我们采用蒸馏loss:
Learning without Forgetting 详解(LwF)_第3张图片
Learning without Forgetting 详解(LwF)_第4张图片
Learning without Forgetting 详解(LwF)_第5张图片

我们采用T=2进行蒸馏。

四、蒸馏的简单解释

神经网络模型在预测最终的分类结果时,往往是通过softmax函数产生概率分布的:
Learning without Forgetting 详解(LwF)_第6张图片
这里将T定义为温度参数,是一个超参数,q_i是i类的概率值大小。

比如一个大规模网络,如ImageNet这样的大网络,能够预测上千种类别,正确类别的概率值能够达到0.9,错误类的概率值可能分布在10-8~10-3这个区间中。虽然每个错误类别的的概率值都很小,但是10-3还是比10-8高了五个数量级,这也反映了数据之间的相似性。

比如一只狗,在猫这个类别下的概率值可能是0.001,而在汽车这个类别下的概率值可能就只有0.0000001不到,这能够反映狗和猫比狗和汽车更为相似,这就是大规模神经网络能够得到的更为丰富的数据结构间的相似信息。
由于大规模神经网络在训练的时候虽然是通过0-1编码来训练的,由于最后一层往往使用softmax层来产生概率分布,所以这个概率分布其实是一个比原来的0-1 编码硬目标(hard target)更软的软目标(soft target)。这个分布是由很多(0,1)之间的数值组成的。

同一个样本,用在大规模神经网络上产生的软目标来训练一个小的网络时,因为并不是直接标注的一个硬目标,学习起来会更快收敛。

更巧妙的是,这个样本我们甚至可以使用无标注的数据来训练小网络,因为大的神经网络将数据结构信息学习保存起来,小网络就可以直接从得到的soft target中来获得知识。

这个做法类似学习了样本空间嵌入(embedding)信息,从而利用空间嵌入信息学习新的网络。
Learning without Forgetting 详解(LwF)_第7张图片

因此:

1、首先用较大的T值来训练模型,这时候复杂的神经网络能够产生更均匀分布的软目标;

2、之后小规模的神经网络用相同的T值来学习由大规模神经产生的软目标,接近这个软目标从而学习到数据的结构分布特征;

3、最后在实际应用中,将T值恢复到1,让类别概率偏向正确类别。

所以,蒸馏神经网络取名为蒸馏(Distill),其实是一个非常形象的过程。

我们把数据结构信息和数据本身当作一个混合物,分布信息通过概率分布被分离出来。首先,T值很大,相当于用很高的温度将关键的分布信息从原有的数据中分离,之后在同样的温度下用新模型融合蒸馏出来的数据分布,最后恢复温度,让两者充分融合。这也可以看成Prof. Hinton将这一个迁移学习过程命名为蒸馏的原因。

文章部分摘自博客。

你可能感兴趣的:(蒸馏,增量学习,深度学习,多任务)