BERT Distillation: A First Look

Contents:

  1. Results
    • Metrics
    • Line charts
  2. Summary
  3. Approach

1. Results

Dataset 1: power-grid data

teacher-dev

                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.9004    0.8791    0.8827        58
   weighted avg     0.9206    0.9138    0.9111        58

student-dev

bert-layer1 (~97 MB)
epoch 19
                 precision    recall  f1-score   support

       accuracy                         0.8103        58
      macro avg     0.6136    0.6035    0.5917        58
   weighted avg     0.8448    0.8103    0.8107        58

epoch 29
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.6305    0.6175    0.6118        58
   weighted avg     0.8661    0.8448    0.8468        58

epoch 39
                 precision    recall  f1-score   support

       accuracy                         0.8966        58
      macro avg     0.7522    0.6769    0.6964        58
   weighted avg     0.9321    0.8966    0.9060        58

epoch 49
                 precision    recall  f1-score   support

       accuracy                         0.8276        58
      macro avg     0.6204    0.6074    0.6011        58
   weighted avg     0.8508    0.8276    0.8296        58

epoch 59
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.6277    0.6144    0.6098        58
   weighted avg     0.8590    0.8448    0.8449        58

bert-layer3 (~150 MB)
epoch  9
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.8619    0.8411    0.8358        58
   weighted avg     0.8542    0.8448    0.8395        58

epoch  19
                 precision    recall  f1-score   support

       accuracy                         0.9310        58
      macro avg     0.9198    0.9013    0.8969        58
   weighted avg     0.9387    0.9310    0.9299        58

epoch  29
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

epoch  39
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

epoch  49
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

epoch  59
                 precision    recall  f1-score   support

       accuracy                         0.9138        58
      macro avg     0.8984    0.8791    0.8735        58
   weighted avg     0.9196    0.9138    0.9104        58

bert-layer3

Same model, with the optimizer switched to AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) (previously torch.optim.SGD(student_model.parameters(), lr=0.05)):

epoch  9
                 precision    recall  f1-score   support

       accuracy                         0.8103        58
      macro avg     0.7917    0.8240    0.7970        58
   weighted avg     0.8203    0.8103    0.8010        58

epoch  19
                 precision    recall  f1-score   support

       accuracy                         0.8621        58
      macro avg     0.8899    0.8443    0.8560        58
   weighted avg     0.8641    0.8621    0.8585        58

epoch  29
                 precision    recall  f1-score   support

       accuracy                         0.8793        58
      macro avg     0.9019    0.8529    0.8658        58
   weighted avg     0.8816    0.8793    0.8752        58

epoch  39
                 precision    recall  f1-score   support

       accuracy                         0.8793        58
      macro avg     0.8814    0.8529    0.8537        58
   weighted avg     0.8834    0.8793    0.8755        58

epoch  49
                 precision    recall  f1-score   support

       accuracy                         0.8621        58
      macro avg     0.8252    0.8443    0.8304        58
   weighted avg     0.8648    0.8621    0.8601        58

epoch  59
                 precision    recall  f1-score   support

       accuracy                         0.8448        58
      macro avg     0.8058    0.8358    0.8141        58
   weighted avg     0.8573    0.8448    0.8461        58


Dataset 2: milk-powder data from a certain brand (classes: 测评 "review", 种草 "recommendation/seeding", 科普 "science education")

teacher-dev

              precision    recall  f1-score   support

          测评     0.7800    0.7800    0.7800        50
          种草     0.8856    0.8754    0.8805       345
          科普     0.6636    0.6887    0.6759       106

    accuracy                         0.8263       501
   macro avg     0.7764    0.7813    0.7788       501
weighted avg     0.8281    0.8263    0.8272       501

student-dev

bert-layer-1
epoch 9
              precision    recall  f1-score   support

    accuracy                         0.6886       501
   macro avg     0.2295    0.3333    0.2719       501
weighted avg     0.4742    0.6886    0.5616       501

epoch 19
              precision    recall  f1-score   support

    accuracy                         0.7265       501
   macro avg     0.6689    0.6261    0.6331       501
weighted avg     0.7444    0.7265    0.7295       501

epoch 29
              precision    recall  f1-score   support

    accuracy                         0.7106       501
   macro avg     0.6198    0.5992    0.6067       501
weighted avg     0.7152    0.7106    0.7118       501

epoch 38
              precision    recall  f1-score   support

    accuracy                         0.7146       501
   macro avg     0.6200    0.5898    0.6007       501
weighted avg     0.7161    0.7146    0.7139       501

epoch 49
              precision    recall  f1-score   support

    accuracy                         0.7146       501
   macro avg     0.6451    0.6156    0.6243       501
weighted avg     0.7269    0.7146    0.7182       501

bert-layer3
epoch 9
              precision    recall  f1-score   support

    accuracy                         0.7605       501
   macro avg     0.7092    0.6881    0.6920       501
weighted avg     0.7782    0.7605    0.7659       501

epoch 19
              precision    recall  f1-score   support

    accuracy                         0.7764       501
   macro avg     0.7030    0.7069    0.7049       501
weighted avg     0.7787    0.7764    0.7775       501

epoch 29
              precision    recall  f1-score   support

    accuracy                         0.7764       501
   macro avg     0.6991    0.6977    0.6981       501
weighted avg     0.7791    0.7764    0.7776       501

bert-layer3

Same model, with the optimizer switched to AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) (previously torch.optim.SGD(student_model.parameters(), lr=0.05)):

epoch 9
              precision    recall  f1-score   support

    accuracy                         0.7685       501
   macro avg     0.6756    0.6662    0.6690       501
weighted avg     0.7734    0.7685    0.7701       501

epoch 19
              precision    recall  f1-score   support

    accuracy                         0.8064       501
   macro avg     0.7360    0.7261    0.7295       501
weighted avg     0.8106    0.8064    0.8078       501

epoch 29
              precision    recall  f1-score   support

    accuracy                         0.7784       501
   macro avg     0.7089    0.6968    0.6998       501
weighted avg     0.7870    0.7784    0.7812       501

Student performance on the milk-powder data:

bert-layer-1

[line chart: dev metrics per epoch]

bert-layer-3 (SGD)

[line chart: dev metrics per epoch]

bert-layer-3 (AdamW)

[line chart: dev metrics per epoch]
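
A sketch of how such line charts can be reproduced from the tables above; the accuracy values are hand-copied from the milk-powder dev reports, and this is illustrative, not the original plotting code:

import matplotlib.pyplot as plt

# dev accuracy per epoch, copied from the tables in section 1 (milk-powder data)
epochs_l1 = [9, 19, 29, 38, 49]
acc_l1 = [0.6886, 0.7265, 0.7106, 0.7146, 0.7146]
epochs_l3 = [9, 19, 29]
acc_l3_sgd = [0.7605, 0.7764, 0.7764]      # first bert-layer3 run (SGD)
acc_l3_adamw = [0.7685, 0.8064, 0.7784]    # second run, after switching to AdamW

plt.plot(epochs_l1, acc_l1, marker="o", label="bert-layer1")
plt.plot(epochs_l3, acc_l3_sgd, marker="o", label="bert-layer3 (SGD)")
plt.plot(epochs_l3, acc_l3_adamw, marker="o", label="bert-layer3 (AdamW)")
plt.xlabel("epoch")
plt.ylabel("dev accuracy")
plt.legend()
plt.show()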

2. Summary

  • The 1-layer student is clearly worse than the 3-layer one (no surprise).
  • No hugely obvious difference between SGD and AdamW so far; filing it under hyperparameter alchemy for now. Update: judging from the line charts above, SGD does look somewhat more stable.
  • Training collapses easily; dev metrics can drop off a cliff in later epochs.
  • The weight ratio between the two losses and the temperature T both feel pretty arbitrary; plenty of room to tune (see the grid-search sketch after this list).
  • The student side of the code is still rough.
  • When the teacher is strong, even a 1-layer student does passably; when the teacher itself is not ideal, a structurally weak student suffers all the more.
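
One way to make the T/alpha tuning less of a black art is a simple grid search. A hypothetical sketch, where train_and_eval is an assumed helper wrapping the training loop from section 3 (not actual code from this post):

best = None
for T in (1.0, 2.0, 4.0, 8.0):          # distillation temperature
    for alpha in (0.1, 0.2, 0.5, 0.9):  # weight of the KL term
        dev_f1 = train_and_eval(T=T, alpha=alpha)  # hypothetical helper
        if best is None or dev_f1 > best[0]:
            best = (dev_f1, T, alpha)
print("best dev f1 %.4f at T=%.1f, alpha=%.1f" % best)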

3. Approach

For the teacher model, my code returns return loss, dense_2_with_softmax, dense_2_output, where dense_2_output is the logits; these will be used later.
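
A minimal sketch of what that forward pass might look like; the full model definition is not shown in this post, so encoder and dense_2 here stand in for the BERT body and the final classification layer:

import torch
import torch.nn.functional as F

class Teacher(torch.nn.Module):
    # hypothetical minimal version of the model described above
    def __init__(self, encoder, hidden_size, label_num):
        super().__init__()
        self.bert = encoder                                     # pretrained BERT encoder
        self.dense_2 = torch.nn.Linear(hidden_size, label_num)  # final classification layer

    def forward(self, input_ids, segment_ids, input_mask, label_ids):
        pooled = self.bert(input_ids, token_type_ids=segment_ids,
                           attention_mask=input_mask)[1]   # pooled [CLS] vector
        dense_2_output = self.dense_2(pooled)               # raw logits
        dense_2_with_softmax = F.softmax(dense_2_output, dim=-1)
        loss = F.cross_entropy(dense_2_output, label_ids)
        return loss, dense_2_with_softmax, dense_2_output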

The student model has essentially the same structure as the teacher, but with different settings in bert_config: I set num_attention_heads to 3 and tried num_hidden_layers of both 1 and 3.
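
The config change could look roughly like this, assuming a HuggingFace-style BertConfig; the exact config plumbing in the original code is not shown:

from transformers import BertConfig

# shrink the student relative to the teacher's config
student_config = BertConfig.from_pretrained(args.bert_model_toy)
student_config.num_attention_heads = 3
student_config.num_hidden_layers = 1   # 1 and 3 were both tried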

The useful parts of the training code are as follows:

model = torch.load("/data/static_MODEL/event_extract/sentence_classify_daneng_teacher/fold_0/model_epoch_23_p_1.0000_r_1.0000_f_1.0000.pt")
student_model = Student(args.bert_model_toy, args.label_num)

Here model is the teacher, loaded directly from an already-trained checkpoint, so it is put into inference mode with model.eval().
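
For example (the explicit requires_grad freeze is an extra safeguard not shown in the post; with the torch.no_grad() wrapper used below it is not strictly necessary):

model.eval()                 # teacher runs in inference mode (no dropout)
for p in model.parameters():
    p.requires_grad = False  # make sure no gradients are tracked for the teacher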

Two optimizers were tried:

# Option 1: the same setup the teacher model used
param_optimizer = list(student_model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

# Option 2: pasted straight from the web
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.05)
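
Note that warmup_steps is computed above but no scheduler shows up in the snippet; presumably a standard linear-warmup schedule was attached to the AdamW setup, something like this (an assumption on my part, not code from the post):

from transformers import get_linear_schedule_with_warmup

# assumed companion to the AdamW setup; call scheduler.step() after each optimizer.step()
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_train_optimization_steps)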

During training in do_train, each batch goes through:

with torch.no_grad():  # the teacher is frozen; no gradients needed
    _, teacher_output_with_softmax, teacher_output = model(input_ids, segment_ids, input_mask, label_ids)
student_output, student_output_with_softmax = student_model(input_ids, segment_ids, input_mask, label_ids)

Later, student_output and teacher_output are what matter: the student learns to match the teacher's output distribution. The formulation most commonly seen in papers is:

L = α · T² · KL( softmax(z_t / T) ‖ softmax(z_s / T) ) + (1 − α) · CE(z_s, y)

where z_s and z_t are the student and teacher logits, y the gold labels, T the temperature, and α the mixing weight; this matches the code below.

The current experiment borrows the following snippet:

def distillation(y, teacher_scores, labels, T, alpha):
    # soft targets: temperature-scaled student log-probs vs. teacher probs
    p = F.log_softmax(y / T, dim=1)
    q = F.softmax(teacher_scores / T, dim=1)
    # KL(teacher || student); T**2 keeps soft-loss gradients on the same scale,
    # then average over the batch (reduction='sum' replaces the deprecated size_average=False)
    l_kl = F.kl_div(p, q, reduction='sum') * (T ** 2) / y.shape[0]
    # hard targets: plain cross entropy against the gold labels
    l_ce = F.cross_entropy(y, labels)

    return l_kl * alpha + l_ce * (1. - alpha)

# call site: T is the distillation temperature, alpha weights the KL term
loss = distillation(student_output, teacher_output, label_ids.long(), T, 0.2)
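
With alpha = 0.2, the soft (KL) term gets weight 0.2 and the hard cross-entropy term 0.8. The remainder of each training step is the usual routine, not shown in the post; a sketch:

optimizer.zero_grad()
loss.backward()
optimizer.step()   # plus scheduler.step() if a warmup schedule is used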
