网络轻量化 - 知识蒸馏(knowledge distillation)

原文:《Distilling the Knowledge in a Neural Network》

目录

  • 前期知识
    • 集成模型(Ensemble Models)
      • Bagging
      • Boosting
      • 缺点
    • 知识蒸馏思想
  • 算法部分
    • 知识蒸馏方法
      • 引入温度参数 T(Temperature)
      • 组合两种 Loss

前期知识

集成模型(Ensemble Models)

通过结合了来自多个模型的决策,以提高最终模型的稳定性和准确性。

网络轻量化 - 知识蒸馏(knowledge distillation)_第1张图片 网络轻量化 - 知识蒸馏(knowledge distillation)_第2张图片

Bagging

  • 从原始样本抽取训练集:每轮从原始样本集抽取 n 个样本,共进行 k 轮抽取,获得 k 个训练集
  • 每次使用一个训练集获得一个模型,共得到 k 个模型
  • 对 k 个模型的预测结果进行组合(例如投票法),得到最终的预测结果

网络轻量化 - 知识蒸馏(knowledge distillation)_第3张图片

Boosting

  • 每次训练使用的是全部样本
  • 减小上一轮训练正确样本的权重,增大错误样本的权重

缺点

  • 模型可解释性差
  • 算力消耗大,运行时间长,不符合移动端需求
  • 模型的选择有随机性,不能确保是最佳组合

知识蒸馏思想

主要思想:

  • 训练一个小的网络,即 student
  • student 模仿预先训练好的大型网络,即 teacher

如何理解 knowledge?
**答:**训练好的模型参数信息保留了模型学到的知识。例如:
[cow, dog, cat, car]
Labels: [0, 1, 0, 0]
Predictions: [0.05, 0.3, 0.2, 0.005]
其中,模型在分类过程中将其分类成 cow 的概率是 car 的十倍,这就是文中所描述的 knowledge,需要从 teacher 网络向 student 网络蒸馏。

算法部分

知识蒸馏方法

  • 训练好的 teacher 网络输出一个类别的概率分布
  • student 网络输出一个类别的概率分布
  • 设计学生网络的 Loss,minimize 以上两个概率分布的差距

引入温度参数 T(Temperature)

teacher 网络的输出结果,其正确分类的概率值非常大,而其他类别的概率值接近于 0,这种结果会忽视其他类别包含的有用信息。
所以引入温度参数 T,可以蒸馏出更丰富的信息,如下图所示:

网络轻量化 - 知识蒸馏(knowledge distillation)_第4张图片

组合两种 Loss

网络轻量化 - 知识蒸馏(knowledge distillation)_第5张图片
具体 Loss Function 如下图:

网络轻量化 - 知识蒸馏(knowledge distillation)_第6张图片 网络轻量化 - 知识蒸馏(knowledge distillation)_第7张图片

注:① 给蒸馏损失较大权重,学生网络损失较小权重时,可以获得更好的结果;
② soft targets 产生的梯度大小会被缩放至 1 / T 2 1/T^2 1/T2,因此在同时使用 soft targets 和 hard targets 时,要将蒸馏损失项乘以 T 2 T^2 T2

你可能感兴趣的:(网络,机器学习,深度学习,神经网络)