BERT等复杂深度学习模型加速推理方法——模型蒸馏

参考《Distilling the Knowledge in a Neural Network》Hinton等

蒸馏的作用

首先,什么是蒸馏,可以做什么?

正常来说,越复杂的深度学习网络,例如大名鼎鼎的BERT,其拟合效果越好,但伴随着推理(预测)速度越慢的问题。

此时,模型蒸馏就派上用场了,其目的就是为了在尽量减少模型精度的损失的前提下,大大的提升模型的推理速度。

实现方法

其实,模型蒸馏的思想很简单。

第一步,训练好原本的复杂网络模型,如BERT,我们称为Teacher模型

第二步,用一个较为简单的模型去拟合Teacher模型,称为Student模型

最后,利用训练好的Student模型进行推理预测。

细节要点

Student模型如何拟合Teacher模型?

首先,我们要理解soft targets(软标签)和hard targets(硬标签)。

soft targets:属于不同标签的概率,一般是softmax的计算结果。例如,我们的标签是“男”和“女”两种,那么软标签的形式应该就是(0.4,0.6)。

hard targets:属于哪一种标签。例如,该样本的标签是“男”,那么硬标签就是(1,0)。

在模型蒸馏中,Student模型应该是去拟合Teacher模型推理的soft targets,因为soft targets包含的信息更多。

那么,Student模型的Loss为:

L o s s = C r o s s e n t r o p y ( s , t ) Loss=Cross entropy(s, t) Loss=Crossentropy(s,t)

s表示Student模型推理的soft targets,t表示Teacher模型推理的soft targets

补充

有些人认为把hard targets也加进去Student的Loss中,即:

L o s s = C r o s s e n t r o p y ( s , t ) + C r o s s e n t r o p y ( s , y ) Loss=Cross entropy(s, t)+Cross entropy(s, y) Loss=Crossentropy(s,t)+Crossentropy(s,y)

y表示样本的真实标签(hard targets)。

Student模型

引用自美团技术团队:

我们使用IDCNN-CRF来近似BERT实体识别模型,IDCNN(Iterated Dilated CNN)是一种多层CNN网络,其中低层卷积使用普通卷积操作,通过滑动窗口圈定的位置进行加权求和得到卷积结果,此时滑动窗口圈定的各个位置的距离间隔等于1。高层卷积使用膨胀卷积(Atrous Convolution)操作,滑动窗口圈定的各个位置的距离间隔等于d(d>1)。通过在高层使用膨胀卷积可以减少卷积计算量,同时在序列依赖计算上也不会有损失。在文本挖掘中,IDCNN常用于对LSTM进行替换。实验结果表明,相较于原始BERT模型,在没有明显精度损失的前提下,蒸馏模型的在线预测速度有数十倍的提升。

蒸馏训练

第一种方法:

在离线的情况下,将Teacher模型对所有样本的推理结果存入磁盘中,然后Student模型从磁盘中读取样本及Teacher模型推理的软标签,进行模型训练。

第二种方法:

将Teacher模型和Student模型同时加载到网络中,但将Teacher模型冻结,只进行前向传播,不进行反向传播更新参数;然后将前向传播的结果传递给Student模型的Loss中,训练Student模型。

第三种方法:

方法2存在这样的缺点:每一个batch,Student模型都需要等待Teacher模型推理结束才能进行反向传播,影响训练速度。

那我们就可以将Teacher模型和Student模型分开部署,进行异步计算。Teacher模型只需要前向传播,而Student模型需要前向和反向传播,在处理时间上可能处于一个持平的状态。

你可能感兴趣的:(深度学习,自然语言处理,深度学习,BERT,模型蒸馏,推理加速)