我将看过的增量学习论文建了一个github库,方便各位阅读地址
论文提出了一种算法,以解决增量学习中的灾难性遗忘问题,与iCaRL将特征提取器的学习与分类器分开不同,本论文提出的算法通过引入新定义的loss以及finetuning过程,在有效抵抗灾难性遗忘的前提下,允许特征提取器与分类器同时学习。
本论文提出的方法需要 e x a m p l a r examplar examplar
训练数据由新类别数据与examplar构成。
设有 n n n个旧类别, m m m个新类别,每个训练数据都有两个标签,第 i i i个训练数据的标签为
模型可以选用常见的CNN网络,例如ResNet32等,按照国际惯例,这一节会介绍distillation loss,作为一篇被顶会接收的论文,自然不能免俗
符号约定
符号名 | 含义 |
---|---|
N N N | 有 N N N个训练数据 |
p i p_i pi | 含义查看上一节 |
q i q_i qi | 含义查看上一节 |
q ^ i \hat q_i q^i | 新模型旧类别分支的输出,为一个 1 ∗ n 1*n 1∗n的向量 |
n n n | 旧类别分支 |
m m m | 新类别分支 |
o i o_i oi | 新模型对于第 i i i个数据的输出,为一个 ( n + m ) ∗ 1 (n+m)*1 (n+m)∗1的向量 |
Classification loss即交叉熵,如下:
L C ( w ) = − 1 N ∑ i = 1 N ∑ j = 1 n + m p i j ∗ s o f t m a x ( o i j ) L_C(w)=-\frac{1}{N}\sum_{i=1}^N\sum_{j=1}^{n+m}p_{ij}*softmax(o_{ij}) LC(w)=−N1i=1∑Nj=1∑n+mpij∗softmax(oij)
其中
s o f t m a x ( o i j ) = e o i j ∑ j = 1 n + m e o i j softmax(o_{ij})=\frac{e^{o_{ij}}}{\sum_{j=1}^{n+m}e^{o_{ij}}} softmax(oij)=∑j=1n+meoijeoij
distillation loss的形式如下
L D ( w ) = − 1 N ∑ i = 1 N ∑ j = 1 n p d i s t i j q d i s t i j L_D(w)=-\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{n}pdist_{ij}qdist_{ij} LD(w)=−N1i=1∑Nj=1∑npdistijqdistij
其中
p d i s t i j = e q ^ i j t ∑ j = 1 n e q ^ i j t q d i s t i j = e q i j t ∑ j = 1 n e q i j t pdist_{ij}=\frac{e^{\frac{\hat q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{\hat q_{ij}}{t}}}\\ qdist_{ij}=\frac{e^{\frac{q_{ij}}{t}}}{\sum_{j=1}^{n}e^{\frac{q_{ij}}{t}}} pdistij=∑j=1netq^ijetq^ijqdistij=∑j=1netqijetqij
L D ( w ) L_D(w) LD(w)即让模型尽可能的记住旧类别的输出分布。t是一个超参数,在本论文中, t = 2 t=2 t=2
个人疑问
distillation loss的作用是让模型记住以往学习到的规律,相当于侧面引入了旧数据集,从而抵抗类别遗忘。
直觉上来说,distillation loss应该只对旧类别数据进行计算,但是新类别数据的旧类别分支输出仍用于计算distillation loss,论文对此给出的解释是“To reinforce the old knowledge”
我认为这种做法的出发点为:旧模型对于新类别数据的输出(经softmax处理),也是一种旧知识,也需要防止遗忘,因此,新模型对于新类别数据的旧类别输出(经softmax处理),与旧模型对于新类别数据的输出(经softmax处理)也要尽可能一致
使用herding selection算法,从新类别数据中抽取部分数据,构成与旧类别examplar大小相等的数据集,此时各类别数据之间类别平衡,利用该数据集,在小学习率下对模型进行微调,选用的loss函数应该是交叉熵。
步骤二使用类别不平衡的数据训练模型,会导致分类器出现分类偏好,finetuning可以在一定程度上矫正分类器的分类偏好
论文给出了两类方法
使用herding selection算法选择新类别数据,构成新类别的 e x a m p l a r examplar examplar
论文训练模型使用了数据增强,具体方式如下:
每个实验都进行了五次训练,取平均准确率
实验过程没有太多有趣的地方,在此不做过多说明
在CIFAR100上的结果如下
img/cls表示每个examplar中图片的个数
首先是选择数据构建examplar的方法,论文比对了三类方法
上述三个选择方法的解释如下:
接下来论文比对了算法各部分对准确率提升的贡献
上述模型的解释如下
类别不平衡会导致灾难性遗忘,模型在学习旧类别时,所使用的数据是充分的,引入知识蒸馏loss,就是尽可能保留旧数据上学习到的规律,在训练时,相当于侧面引入了旧数据。
论文在distillation loss的基础上又引入了类别平衡条件下的finetuning,相当于进一步抵抗增量学习下类别不平衡的导致的分类器偏好问题,由此取得模型性能的提升。