知识蒸馏一个重要目的是让学生模型学习到老师模型的泛化能力,让轻量级的学生模型也可以具备重量级老师模型的几乎同样的能力。
一个很高效的蒸馏方法就是使用老师网络softmax层输出的类别概率来作为软标签,和学生网络的softmax输出做交叉熵。
传统训练方法是硬标签,正类是1,其他所有负类都是0。但知识蒸馏的训练过程过程是用老师模型的类别概率作为软标签。
大模型虽然效果很好,但模型较重推理速度太慢无法瞒足工业要求,而小模型轻,推理速度快,但是直接使用数据训练效果较差,知识蒸馏就是想让小模型在拥有较快的推理速度下,也具备大模型的泛化能力。
原始的softmax:
上述有说到,知识蒸馏是student模型学习tearch模型的软标签,但是如果
根据公式可以看出,让T越大时,softmax输出值越平滑,输出值得熵越大,会放大负标签携带的信息,模型会相对校对的关注负标签,能够充分的学习。 一般来讲T会大于1;
说白了温度的高低改变的是学生网络对负标签的关注程度
实际上,负类别中包含一定的信息,尤其是那些概率值较高的负类别。 但由于老师网络的负类别可能会有噪声,并且负类别的概率值越低,其信息就越不可靠。因此温度的选取比较看经验,本质上就是在下面两件事之中取舍
总的来说,温度的选择和学生网络的大小有关,学生网络参数量比较小的时候,相对比较低的温度就可以了,因为参数量小的模型不能捕获所有知识,所以可以适当忽略掉一些负标签的信息。
distill loss(对应软标签) :
是老师模型softmax经过高温后输出的概率分布和学生网络在同等温度下的概率分布做交叉熵, 软标签 loss:
student loss(对应硬标签) :
是学生网络在温度为1下的概率分布和真实标签做交叉熵,硬标签 loss:
- 这里我使用了一种soft-label的方法。是将teacher模型的logits表示经过softmax后,与one-hot表征的实际label进行相加, 注意这里引入相加的权重alpha,实验做好的值为0.5,alpha越大越依赖教师模型的logits。然后得到一个新的label表示。如teacher-logit-softmax = [0.2,0.7,0.1],实际标签one-hot = [0,1,0],alpha = 0.5,那么最后的label = [0.1,0.85,0.05],然后用这个label和student计算获得的logits进行MSELoss计算,求导。这种方法获得的结果和直接用MSELoss计算后,然后使用alpha权重相加结果类似,但好处是少了一次MSELoss的计算过程,在训练时,训练速度更快。
因为是中文NLP任务,对于teacher模型选择的标准是,尽量好,尽量优秀,甚至可以使用集成学习的方法获得最优结果。
项目中使用了中文的Roberta-base模型作为teacher模型(已经对下游NER实体进行了Finetune,精度F1 = 94.67%),具体参数:Epoch=3,max_sequence_length = 256,batch_size = 32,model_size = 42.2M。
前后使用多种Student模型,选取的条件是速度满足当前模型预测的速度要求,(不做蒸馏前,纯训练)精度越高越好。
将BERT模型蒸馏至TextCNN 和BiLSTM等小模型上,精度下降3%,速度提升400倍。注意文章使用了word2vec词向量,并非完全从头训练,具体细节可看论文和代码。
蒸馏学习的Student模型分别使用了ALBert-Base,ALBert-Tiny和ELECTRA-Small这三种模型,模型使用的alpha = 0.5,使用的是MSELoss的方法,具体结果:
Model | Base-f1 | Distill-f1 | M-Size | I-Time |
---|---|---|---|---|
Roberta-Base | 0.94675 | — | 412M | 95.35s |
AlBert-base | 0.810 | 0.909 | 42.2M | 27.3 |
ALBert-Tiny | 0.612 | 0.824 | 16.3M | 9.3s |
ELECTRA-small | 0.9124 | 0.9267 | 49.4M | 24.8s |
使用训练好的teacher模型对数据打标,形成伪标签,再训练student模型,即使部分case teacher没有标对,也没有很大的关系,目的就是让student更像teacher,本来badcase就很小,对训练影响度有限,但是伪标签数据不易过度,以免真正影响效果。
如上述实验中的奖RoBerta-Base模型内容蒸馏到ALBert-Tiny,模型的size差异大约在30倍,如果直接蒸馏,效果会不好。精度大约只能达到82.4%。这里可以借鉴miniLM(Ref-9)的一种操作Trick,间接蒸馏。具体做法是先将大模型(如:RoBerta-Base,94.7%)里的知识蒸馏到一个中(过渡)模型(如:ELECTRA-base,92.1%),然后再用中模型作为teacher,将知识蒸馏到真正的小模型(如此处的ALBert-Tiny),模型精度最终可以达到88.3%,精度大约有6个点的提升。
温度T,高温T。通常模型训练的时候使用高温T,而在模型测试和预测阶段的时候,是不使用teacher模型的,仅使用student模型进行测试和预测,也就是T在预测阶段不使用。
MSELoss计算L-soft。不使用上述复杂的L-soft,而使用简单的均方差损失函数——MSELoss。
Hard-Loss加入模型。即便使用了Soflt-Loss,还是需要引入Hard-Loss以及超参数 目的是teacher模型也可能存在无法完全学对的可能,所以在数据质量有保证的情况下,引入学生模型的hard-loss能更好的学会teacher无法学会的知识。实际使用过程中也发现,引入hard-loss很有效果。
使用更多损失函数。,Hinton的蒸馏学习使用的是Cross-Entorpy作为损失函数,其实损失函数不止于交叉熵损失函数,包括MSELoss,NLLLoss,HingeLoss等。实验中,我使用了MSELoss和CELoss做比较,发现二者对于Student模型的效果类似,所以对于不同下游任务,可以使用更贴合的Loss函数,不必局限于CELoss。但是对于蒸馏学习的理解一定要到位,才能更合理的利用Loss。
集成学习加入到蒸馏学习中。通常我们不会仅仅使用一个老师,而是使用多个teacher,然后将多个teacher的知识权重相加引入到student模型中。这是将集成学习和蒸馏学习相融合,能让学生学到更多信息,但是也同时增加了模型训练的难度(增多了超参数的数量以及集成学习方法的比较),对于初学者不建议使用。
因为老师网络也有一定的错误率,使用ground truth可以有效降低错误被传播给学生网络的可能。
举例,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。
这是一个经验的结论。一个可能的原因是,由于soft target产生的gradient与hard target产生的gradient之间有与 T 相关的比值。
在同时使用soft target和hard target的时候,需要在soft target之前乘上 T 2 T^2 T2 这个系数,这样才能保证soft target和hard target贡献的梯度量基本一致。
直接match logits指的是,直接使用softmax层的输入logits(而不是输出)作为soft targets,需要最小化的目标函数是Net-T和Net-S的logits之间的平方差。直接上结论: 直接match logits的做法是 T→∞ 的情况下的特殊情形。