bert 模型压缩原理

1. 压缩目的:

在基本不影响模型效果的基础上,对bert模型进行同构压缩,将layer 与embedding size减少, 尽可能提升模型的性能。

比较经典的压缩尺寸是 12 * 768 -> 6 * 384

下面以classifier task为例子,讲一下bert模型压缩的原理和实现.

classifier task的model的 结构:

 BERT --> MLP -->cross_entropy_loss

2. 基本概念

teacher model: 尺寸较大的模型, finetune model

student model: 尺寸较小的模型,target model

3. distillation loss的设计

distillation可以分为两步。第一步,使用classifier task的label 训练teacher model,如果要做的精确一点,可同时训练student model的classifier 以及teacher的sequence attention 的logits和student 的sequence attention logits做交叉熵.


loss1 -> grad -> loss2 -> grad -> loss3->grad

第二步,将teacher model 的 parameters 做冻结,detach(), 使用MSE Loss的方式修正student model的Mlp logits的结果


总结:第一步,主要实现teacher model的finetune和提高student的BERT layer与teacher BERT layer的sequence结果相关性

第二步:实现student MLP logits 与teacher MLP logits 的相关性.

实验证明可以基本实现在效果减小很少的情况下,性能有很大提升。

第一步的具体的流程可表示为:

1. teacher_sequence = teacher_sequence.detach() 做梯度冻结

  teacher_attention = torch.matmul(teacher_sequence , teacher_sequence.permute(0,2,1))

  input_mask = torch.unsqueeze(input_mask, 0) * torch.unsqueeze(input_mask, 1)

 将input_mask 也变成batch size * sequence * sequence的序列组合的形式.

teacher_att = torch.log_softmax(teacher_attention) * input_mask [使用input_mask将原序列中需要编码忽略的部分置0, 必要的时候softmax前可以将相应的mask掉的部分的值调低)

对student_sequence 采用同样的操作.

att_loss = teacher_att * torch.log(student_att)/(torch.sum(input_mask))

第二步的具体流程可表示为:

teacher_logits = teacher_logits.detach()

mse_loss = nn.MSE()(student_logits, teacher_logits)

你可能感兴趣的:(bert 模型压缩原理)