一. Seq2seq模型蒸馏方法总体过程如下
1. 训练teacher模型
2. 产生student模型
3. 利用teacher模型预测的logits和来自语料的true labels来计算student 模型的训练过程中的loss。
二. 涉及的具体步骤和参数有
1. 训练参数量相对较大的teacher模型。
2. 生成student模型,可以从teacher模型结构中抽取部分层组成,也可以随机初始化student模型的参数。
如果从teacher模型中抽取,则可以在训练时固定某些层,例如可在训练时freeze_embeds.
如果student模型的encoder和teacher模型的encoder完全一致,在训练时,可以考虑freeze_encoder。其他情况则不考虑freeze_encoder。
3. 根据teacher logits产生的时间不同,模型蒸馏可分为在线蒸馏和离线蒸馏。
离线蒸馏是采用teacher模型,预先将decoder端每个token对应的词表(或类别)大小的概率分布预测出来,在训练时和true label一起输入来计算loss。
在线蒸馏是同时将teacher模型和student模型加载到训练机上,在训练时利用teacher模型来预测每个token位置的概率分布(logits), 同时和true label一起参与loss的计算。
在线蒸馏时,teacher模型参数固定,只有student模型的参数为trainable状态。
三. 关于loss的计算
1. Loss共有3部分构成,即来自teacher_logits的loss_ce, 来自true_labels的loss_mlm, 和来自中间层的loss_hid.
对应3个loss部分在总的loss中的比例系数可以分别用alpha_ce, alpha_mlm, 和alpha_hid表示。因此总的loss可以表示为:
loss_total = (alpha_ce * loss_ce) + (alpha_mlm * loss_mlm) + (alpha_hid * loss_hid)
其中,
loss_ce = distill_loss_fn(student_logits, teacher_logits,temperature)
loss_mlm = loss_fn(student_logits, true_labels)
loss_hid = mse_loss(student_hid, teacher_hid).
2. 关于loss_hid可以这样理解,采用teacher中的某些层来监督student中的各层的结果。例如采用一个12层的teacher模型,来蒸馏一个3层的student模型,如果只关注encoder端,可以用teacher_encoder [0, 6, 11]层来分别监督student_encoder [0, 1, 2]层的训练结果。
如果是离线蒸馏,并且需要在loss中计算student各层的损失,则在需要将teacher模型各层的结果,和teacher logits一起预先计算并保存。
3. loss中涉及的3个部分的损失函数不同,其中mlm对应的是一般的cross_entropy, hid对应的为mse,ce部分对应的为和温度相关的KLDivLoss, loss_ce具体可以描述为:
loss_ce = KLDivLoss(
softmax(student_logits/temperature, dim=-1), # vocab_size
softmax(teacher_logits/temperature, dim=-1)
) * (temperature ^ 2)
关于最后需要乘温度的平方,可以阅读【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 - 知乎,简单表述为,loss_ce乘上 temperature^2 后,与loss_mlm的值相当,因此为了平衡loss_ce对损失的贡献,要乘temperature^2 。
4. 机器翻译(生成式)中使用离线蒸馏的问题
尝试将100句中英语料的teacher logtis预测并保存,发现保存后的文件为869M(.npy格式).这很大原因是因为label的维度过大导致的,因为teacher logits的最后一个维度为词表大小,词表大小为5w左右(裁剪后的mbart50模型)。
考虑到机器翻译的语句对经常为千万级别,对teacher logtis的存储空间要求较高,因此离线蒸馏在现有方法改进之前,并不适用机器翻译。
5. 综合来看,蒸馏涉及的主要参数有
--teacher_model
--student_encoder_layers=3
--student_decoder_layers=3
--temperature=2
--alpha_ce=0.5
--alpha_mlm=0.5
--alpha_hid=0
--freeze_encoder=False
--freeze_embeds
--max_sentence_length=64
--train_batch_size
--train_epochs=5