使用知识蒸馏提升模型推理性能

目录

知识蒸馏介绍

Logits

Temperature

理论介绍

实验代码

实验结果


知识蒸馏介绍

首先,我们先简单地了解下知识蒸馏概念[2]。

通常,大模型可能是一个复杂的网络或多个网络的组合,表现出优越的效果和泛化能力。而小模型由于其较小的规模,其表达能力可能受到限制。为了提高小模型的效果,我们可以借助大模型所学习到的知识来指导小模型的训练。这样,小模型在参数数量明显减少的情况下,也能够达到与大模型相似的效果。这种策略就是知识蒸馏在模型压缩中的实践应用。

Geoffrey Hinton及其团队在论文Distilling the Knowledge in a Neural Network[3]中首次提出了“知识蒸馏”的思想,这是知识蒸馏中的开山之作,分量十足。其核心理念是首先训练一个大型复杂的网络,接着利用这个大网络的输出以及数据的真实标签来训练一个更轻量级的网络。在知识蒸馏的结构中,这个大型网络被称为“Teacher”模型,而轻量级网络则被称为“Student”模型。

现阶段,知识蒸馏已经有了长足的发展,方法繁多。常见的知识蒸馏可分为目标蒸馏特征蒸馏

  • • 目标蒸馏:Student模型只学习Teacher模型的Logits结果知识,一般为Soft Logits

  • • 特征蒸馏:Student模型学习Teacher网络结构中的中间层特征,利用Teacher模型的信息更加充分,训练难度更大

本文主要介绍目标蒸馏,一般的目标蒸馏模型结构[4]如下图所示:

 

使用知识蒸馏提升模型推理性能_第1张图片

图1:知识蒸馏模型结构

步骤如下:

  1. 1. 在原有训练数据集中训练好Teacher模型,一般为复杂的(大)模型;

  2. 2. 借助温度参数(Temperature,T)和Teacher模型的Logits结果,产生Soft Labels;

  3. 3. 在相同训练集上训练小模型(Student模型),最终loss为两部分loss的权重和:大模型的Soft Labels和小模型的Soft Labels(Predictions)的K-L loss;小模型在数据集上的交叉熵损失

  4. 4. 使用训练好的小模型做最终的模型推理(Inference)

让我们暂时脱离模型架构,来理解两个重要的概念:LogitsTemperature.

Logits

Logits指的是在分类模型结构中,最后的Softmax函数作用前,在各个标签上的分数z_i,称为Logits。当Softmax函数作用在Logits上,会得到各个标签上的概率值p_i,总和为1。

基于Logits概念,当我们的概率值中只有一个为1,其余为0,则称这些概率值分布为Hard-Target;其它情况为Soft-Target。一般,真实的标签表示方法(通常采用One-Hot表示法)为Hard-Target,只有命中的标签为1,其余标签为0;而模型训练产出的标签结果为Soft-Target,因为不存在概率为1的标签,只有无限接近于1的标签。

通过上述的知识蒸馏模型结构介绍,我们知道,Student模型会利用Teacher模型产生的Soft-Target。那么,Soft-Target有何用处呢?

我们来看个Soft-Target的例子[5]。

 

使用知识蒸馏提升模型推理性能_第2张图片

图2: Soft Target & Hard Target

观察上面的两个样本,他们的真实标签均为数字2,因此Hard-Target一致。但我们观察它们的Soft-Target,第一个值的预测结果为2,但在数字3上面的概率值会比其它数字更高,因为从图片中看这个数字2,有点像数字3;同理,第二个值在数字7上的概率比其它数字更高,有点像数字7,这从图片中也能得到反映。

因此,与Hard-Target相比,Soft-Target能够反映样本特征的更多信息,有其合理之处,这就是为什么我们要利用Teacher模型的Soft-Target,它在Hard-Target之外,还能告诉Student模型更多关于样本的特征信息,因此对Student模型有指导意义。

Temperature

那么,Teacher模型的Soft-Target对Student模型有指导意义,还能再加强Soft-Target的作用吗?参数Temperature便应运而生。

对于Softmax函数,我们有如下公式:

 

a7add30265d85ad34aebf4e7d443dbf8.png

其中,z_i为Logits值,p_i为概率值。我们将温度系数T作用在该公式中,得到:

 

73dc17a7e6dcdc216e5df538d2a39e07.png

由上述公式,我们可得到:

  • • 随着T的增大,各个概率值将趋向平滑(当T为无穷大时,概率值相同),其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签

  • • 随着T的减小,各个概率值将趋向陡峭,分布的熵越小,负标签携带的信息被相对放小,模型训练更少关注负标签

我们来做个小小的实验,以原始Logits分布[-1, 1, 3, 2, 0.5]为例,考察温度T对概率值分布的影响:

 

使用知识蒸馏提升模型推理性能_第3张图片

图3: 温度T对概率值分布的影响

Python实现代码如下:

from math import exp
import matplotlib.pyplot as plt


def softmax(logit_list):
    s_sum = sum([exp(_) for _ in logit_list])
    return [exp(_)/s_sum for _ in logit_list]


logits = [-1, 1, 3, 2, 0.5]

for i, T in enumerate([0.2, 0.5, 0.8, 1, 2, 5, 10]):
    post_logits = softmax([_/T for _ in logits])
    if i < 3:
        j = i + 1
    elif i == 3:
        j = i + 2
    else:
        j = i + 3
    plt.subplot(3, 3, j)
    plt.bar(range(len(post_logits)), post_logits, color=list('bbrbb'))
    plt.title(f"Temperature: {T}")
    plt.xticks([])

plt.show()

理论介绍

知识蒸馏介绍章节中,我们已经介绍了知识蒸馏的模型结构图(图1)。同时,还介绍了知识蒸馏的步骤。

  1. 1. 在原有训练数据集中训练好Teacher模型,一般为复杂的(大)模型;

  2. 2. 借助温度参数(Temperature,T)和Teacher模型的Logits结果,产生Soft Labels;

  3. 3. 在相同训练集上训练小模型(Student模型),最终loss为两部分loss的权重和:大模型的Soft Labels和小模型的Soft Labels(Predictions)的K-L loss;小模型在数据集上的交叉熵损失

  4. 4. 使用训练好的小模型做最终的模型推理(Inference)

第1步为常规的(大)模型训练,第2,3步被称为蒸馏。在2,3步中,最终的输出loss表示如下:

 

db75e718f38b755fcf29b4d04af25813.png

其中,Teacher模型和Student模型的Soft Target如下:

 

使用知识蒸馏提升模型推理性能_第4张图片

L_soft为Teacher模型和Student模型的Soft Target的KL-Loss, L_hard为Student模型在训练集真实标签上的交叉熵。最终Loss为两部分的权重和。

由于L_soft贡献的梯度大约为L_hard的1/(T^2),因此在同时使用Soft-target和Hard-target的时候,需要在L_soft的权重上乘以T^2,这样才能保证Soft-target和Hard-target贡献的梯度量基本一致。

实验发现,当L_hard权重较小时,能产生最好的效果,这是一个经验性的结论。

实验代码

在本文中,数据集采用Sougou小样本数据集,Teacher模型采用BERT训练,模型名称为bert_base_sougou_trainer_128/checkpoint-96,Student模型采用ckiplab/bert-tiny-chinese模型[6]。实验过程详细介绍如下:

  • • 导入模型名称

student_id = "ckiplab/bert-tiny-chinese"
teacher_id = "./bert_base_sougou_trainer_128/checkpoint-96"
  • • 验证tokenizer

from transformers import AutoTokenizer

# init tokenizer
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id)
student_tokenizer = AutoTokenizer.from_pretrained(student_id)

# sample input
sample = "这是一个基本例子,使用不同的汉字进行测试。"

# assert results
print(teacher_tokenizer(sample), student_tokenizer(sample))
assert teacher_tokenizer(sample) == student_tokenizer(sample), "Tokenizers haven't created the same output"
  • • 加载数据集

# load dataset
import datasets
data_files = {"train": "./sougou/train.csv", "test": "./sougou/test.csv"}
raw_datasets = datasets.load_dataset("csv", data_files=data_files, delimiter=",")
  • • 加载tokenizer并对文本进行tokenize

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(teacher_id)
from transformers import DataCollatorWithPadding

def tokenize_function(sample):
    return tokenizer(sample['text'], max_length=128, truncation=True)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
  • • 创造知识蒸馏网络的训练参数和Trainer

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)

        self.alpha = alpha
        self.temperature = temperature

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher,self.model.device)
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False):

        # compute student output
        outputs_student = model(**inputs)
        student_loss=outputs_student.loss
        # compute teacher output
        with torch.no_grad():
          outputs_teacher = self.teacher(**inputs)

        # assert size
        assert outputs_student.logits.size() == outputs_teacher.logits.size()

        # Soften probabilities and compute distillation loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1)) * (self.args.temperature ** 2))
        # Return weighted student loss
        loss = self.args.alpha * student_loss + (1. - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
  • • 设置训练参数,加载模型

from transformers import AutoModelForSequenceClassification

# define training args
training_args = DistillationTrainingArguments(
    output_dir='kd_sougou_trainer_128',
    evaluation_strategy="epoch",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    num_train_epochs=5,
    warmup_ratio=0.2,
    logging_dir='./sougou_train_logs',
    logging_strategy="epoch",
    save_strategy="epoch",
    report_to="tensorboard",
    # distilation parameters
    alpha=0.5,
    temperature=1.0
    )

# define model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=5
)

# define student model
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    num_labels=5
)
  • • 创建准确率计算指标

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
  • • Trainer类实例化,进行训练

trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

实验结果

在上述的实验代码下,我们分别对单独Student模型训练、蒸馏过程、单独Student模型Optuna参数优化、蒸馏过程Optuna参数优化进行实验,统计在测试集上的Weighted F1值,如下:

模型 Weighted F1
Teacher模型(单独) 0.9737
Student模型(单独) 0.9050
蒸馏 0.9050
Student模型(单独,参数优化) 0.9331
蒸馏(参数优化) 0.9454

在本地实验中,如果不进行参数优化,则蒸馏的效果不一定会比单独训练Student模型效果来得好;但进行参数优化后,蒸馏效果优于单独训练Student模型,且只比Teacher模型下降了2.8%。

对比蒸馏后的小模型和Teacher模型的平均推理时间,结果如下:

基座模型 推理时间(ms) 模型大小
bert-base-chinse 341.5 ~412MB
bert-tiny-chinese(蒸馏后) 45.55 ~46MB

推理速度提升了7.5倍,这比量化策略的提升1.8倍强太多了。

 

你可能感兴趣的:(AI(人工智能),内容分享,NLP(自然语言处理)内容分享,深度学习,人工智能)