知识蒸馏(尝试在ASR方向下WeNet中实现--代码)

知识蒸馏(尝试在WeNet中实现知识蒸馏)

  • 1、知识蒸馏简介
    • 1.1、论文
    • 1.2、目标蒸馏结构解释
  • 2、WeNet中关于知识蒸馏的思考
    • 2.1、WeNet结构
    • 2.2、CTC端的蒸馏
  • 3、WeNet知识蒸馏实验
  • 4、总结

1、知识蒸馏简介

蒸馏的作用:模型压缩、加速训练
蒸馏的知识:大模型logits层、激活、神经元、中间层特征、教师网络参数等
蒸馏的结构:大的教师模型、小的学生模型(大小什么意思?可以理解为我们的层数少一些,网络结构少一些)

1.1、论文

论文:《Distilling the Knowledge in a Neural Network》
详细信息可自己查看该论文。论文中有提及ASR,因此本文主要是实现ASR中的知识蒸馏。

论文综述:《Knowledge Distillation A Survey》
可查看知识蒸馏的基础知识与常用蒸馏方式。

1.2、目标蒸馏结构解释

本文主要进行logits蒸馏(还可以有基于特征、基于关系的蒸馏方式)。其实主要的就是计算蒸馏中的两个loss,soft loss和hard loss。

  • Soft Loss
    softoutput1(softtarget):教师模型,经过教师模型后的输出除以温度T后再经过softmax层后得到softoutput1。
    softoutput2:学生模型,经过学生模型后的输出除以温度T后再经过softmax层后得到softoutput2。
    softoutput1与softoutput2求解Soft Loss(这里softloss我使用了KL相对熵loss)
  • Hard Loss
    hardoutput:学生模型,经过学生模型后的输出再经过softmax层后得到softoutput。
    hardtarget:真实的实验数据
    hardoutput与hardtarget求解Hard Loss(这里就是正常不进行蒸馏的训练)

知识蒸馏(尝试在ASR方向下WeNet中实现--代码)_第1张图片

2、WeNet中关于知识蒸馏的思考

2.1、WeNet结构

如图可以看到,wenet包括一个贡共享的编码器和两个解码器结构,左解码器为CTC,右解码器为Attention。因此我们在进行蒸馏的时候应该考虑到他的两个输出,即CTC端的输出,以及Attention端的输出。
知识蒸馏(尝试在ASR方向下WeNet中实现--代码)_第2张图片

2.2、CTC端的蒸馏

在Wenet中进行蒸馏,这里考虑对Attention端直接进行蒸馏,即对输出直接进行loss求解。对CTC端分为不同级别的蒸馏,具体如下:

  • frame
    正常训练中我们对数据的处理即处理到帧级别,因此wenet中进行frame级别的蒸馏也就是对encoder(教师模型和学生模型)的输出除以温度T,求解Soft Loss,这里的Soft Loss使用pytorch自带的nn.KLDivLoss.

这里过两天我debug之后再写哈哈哈

  • sequence_sample
  • sequence_max

3、WeNet知识蒸馏实验

实验:

教师模型:已经训练完成的大模型。

学生模型:教师模型的一半(什么意思呢?就是教师模型的编码器有12层,解码器有6层,做卷积的时候卷积核为31,那我学生模型的编码器就设置6层,解码器有3层,做卷积的时候卷积核为15)

温度:有试过1、1.2、2,发现温度为1时控制其他变量后效果更好,因此温度T=1,其他配置一致。

代码的话过两天我debug之后再写哈哈哈

4、总结

知识蒸馏在wenet中的优缺点(基于3中的实验而言)

优点

  • 模型大小:从513M压缩到85M(相较于量化而言压缩得更小)。
  • 语音wer:测试过包括aishell、magicdata、aidatatang、wenetspeech等测试集,wer基本一致,且在部分解码方式下知识蒸馏后的模型效果更好(相较于量化而言wer也更小)。
  • 训练速度:训练1wh数据,4张3090卡,batchsize为18,一个epoch能够从9小时变为8.5小时左右。

缺点:

  • 实现过程:相比于wenet自带的量化功能而言,实现知识蒸馏比较麻烦。且对语音识别过程进行知识蒸馏而言,现在研究感觉还比较少,处于试错的阶段(比如对温度T的把握,就要多做尝试)

你可能感兴趣的:(人工智能,语音识别)