知识蒸馏7:知识蒸馏代码详解

代码说明

与yolov5-v6.1代码的区别:

  • data/hyps/hyp.scratch-low-distillation.yaml(添加的文件,修改hyp.scratch-low.yaml得到)
  • utils/loss.py 添加一个函数compute_distillation_output_loss
  • train_distillation.py( 添加的文件,修改train.py得到)

hyp.scratch-low-distillation.yaml

知识蒸馏7:知识蒸馏代码详解_第1张图片

  • 该文件相对于原来的hyp.scratch-low.yaml,多了dist超参数 ,可以在[0,1]范围内调整,接近1的话网络会更重视蒸馏损失,靠近0的话就更倾向于detection损失,该草参数用来平衡detection 损失和蒸馏损失。
  • 对于box loss ,clss loss,obj loss也有对应的加权损失。

utils/loss.py

utils/loss.py 添加一个函数compute_distillation_output_loss

你可能感兴趣的:(模型轻量化,深度学习,人工智能,python)