【知识蒸馏】|Distilling the Knowledge in a Neural Network

Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。

hard target:
hard target其实就是数据的标签,比如分类驴和马,识别马时,hard target只有马的index处的值为1,驴的标签则为0.

soft target:
soft target是来自于teacher model的检测结果。hard target包含的信息熵很低,同时,hard target容易引起过拟合。可以这么理解,假设model 推测该类别的置信度为0.6,而label是1,那么相较于软化后的label如0.8,硬标签计算出的loss就会更大,使得bp的幅度变大,这样容易引起模型对某些特征的偏好,从而降低泛化能力。为了让student model学习更多有价值的信息,作者软化了标签,如下式,增加温度参数T软化了label。

在这里插入图片描述
loss是两者的结合。代码如下,soft target部分是KL散度,hard target是交叉熵。

def distillation(y, teacher_scores, labels, T, alpha):
    p = F.log_softmax(y/T, dim=1)
    q = F.softmax(teacher_scores/T, dim=1)
    l_kl = F.kl_div(p, q, size_average=False) * (T**2) / y.shape[0]
    l_ce = F.cross_entropy(y, labels)
    return l_kl * alpha + l_ce * (1. - alpha)

ref

https://aistudio.csdn.net/62e38a4acd38997446774bb1.html?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Eactivity-2-80568658-blog-103056135.pc_relevant_aa&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Eactivity-2-80568658-blog-103056135.pc_relevant_aa&utm_relevant_index=2

你可能感兴趣的:(知识蒸馏,python,深度学习,开发语言)