Github:https://github.com/sunnweiwei/MixCL
给定一个问题或上下文 x x x,一个对应检索的知识 K \mathcal{K} K,目标是根据上下文和知识来生成回复 y y y。
目前对话有两种模式,如下图:
本文则关注LM模式
(1)Pre-training:采用BART作为语言模型:
(2)SFT(Fine-tuning):采用MLE目标在对话数据集上进行自回归式训练:
然而MLE损失鼓励模型盲目模仿训练数据并导致模型幻觉,其过度依赖于前面的token,容易导致误差传播。
研究发现,使用标准 MLE 训练的模型可能会过度依赖之前预测的标记,从而加剧错误传播(Wang 和 Sennrich 2020)。 结果,在推理阶段,随着生成序列的增长,错误沿着序列累积,模型往往会放大错误并产生幻觉内容。
Studies have found that models trained with standard MLE may over-rely on previously predicted tokens, exacerbating error propagation (Wang and Sennrich 2020). As a result, during the inference stage, as the generated sequence grows, the errors accumulate along the sequence, and the model tends to amplify errors and generate hallucinating contents.
本文提出MixCL,一种基于混合对比学习的训练策略来降低模型幻觉。
方法如下图所示:
主要包括两个核心步骤:Negative Sampling和Mixed Contrastive Learning
z + z^{+} z+表示正确的知识或文本片段,代表positive,其通过一个函数 Q P o s ( x ) Q_{Pos}(x) QPos(x)来实现positive的获取。该函数输入的是原始的文本 x x x,输出正确的知识片段,可以是人工标注,也可以是启发式规则。
z − z^- z−表示negative,即non-factual或与输入 x x x存在不相关的知识(irrelevant knowledge)片段。本文设计两种获得 z − z^- z−的方法:
(1)检索式:采用TF-IDF retriever,给定输入文本 x x x和一个知识库 K \mathcal{K} K,输出一组 z − z^- z−。由于采用TF-IDF,采样得到的片段与输入文本存在一定的confusion,但依然是negative;
(2)模型生成式:提出一种bootstrapping策略,在模型生成时获得negative
使用NLI工具约束模型生成的片段不包含正确的知识。
基于上述两个方法,最终构建得到负采样函数:
首先对比学习的loss设计如下所示:
l l l表示cross-entropy loss, M M M为负样本的数量。
在BERT或GPT模式的训练中,通常 l l l要么是基于token的loss,要么是基于sentence的loss。然而模型产生的幻觉通常是一个文本区间(span),因此本文提出基于span的对比学习。
(1)抽取区间
首先要从positive和negative文本中分别抽取区间。
考虑到幻觉有内部幻觉和外部幻觉,因此设计两种span抽取策略。
(2)构建Mixing example
参考Mix-up等工作,将一个正样本和负样本进行mix-up: z ~ = M i x ( z + , z − ) \tilde{z}=Mix(z^+, z^-) z~=Mix(z+,z−)。
具体操作如下所示:
其实,0/1表示的是混合后的序列 z ~ \tilde{z} z~对应的token是负样本/正样本。
(3)Contrastive Loss
对于整个数据集,给定一个输入 x x x,先获得对应的一个正样本 z + z^+ z+,然后采样获得 M M M个负样本 z i − z_i^- zi−。
所有输入 x x x对应的总的loss定义如下:
对于某一个正样本 z + z^+ z+和负样本 z i − z_i^- zi−的pair,其loss定义如下所示:
其中 z ~ i = M i x ( z + , z i − ) \tilde{z}_i=Mix(z^+, z_i^-) z~i=Mix(z+,zi−), ∣ z ~ i ∣ |\tilde{z}_i| ∣z~i∣表示这个序列的token数量, ϕ i j \phi_{ij} ϕij表示 z ~ i \tilde{z}_i z~i的第 j j j个token是否是positive。
可知该loss依然是站在基于token的Causal Languege Modeling目标,但是不同的是,对应的token有的是来自positive,有的是negative,negative token可以认为是训练过程中模拟的幻觉部分。
最终总的训练loss为:
初始化时, α 1 = 0.4 \alpha_1=0.4 α1=0.4, α 2 = 0.3 \alpha_2=0.3 α2=0.3, α 3 = 0.3 \alpha_3=0.3 α3=0.3
随后这些参数进行线性变化,最终 α 1 = 0.5 \alpha_1=0.5 α1=0.5, α 2 = 0.5 \alpha_2=0.5 α2=0.5, α 3 = 0 \alpha_3=0 α3=0。
之所以一开始 α 3 > 0 \alpha_3>0 α3>0,目的是为了防止模型灾难性遗忘。
Wizard-of-Wikipedia(WoW)
F1、ROUGE-L、BLEU(2/4)、MT、Knowledge-F1、Entity-F1、Acc。
F1 (Dinan et al. 2019) calculates the unigram F1 between the generated text and the ground-truth text. For ROUGE (Lin 2004) we use ROUGE-L (RL for short) following previous work. BLEU (Papineni et al. 2002) we use BLEU-2 and BLEU-4 (or B2 and B4 for short) and use the implementation in the NLTK Toolkit. MT (Meteor) (Denkowski and Lavie 2014) is based on the harmonic mean of unigram precision and recall. Knowledge-F1 (Dinan et al. 2019) (or KF1 for short) calculates the F1 between the generated response and the ground-truth knowledge sentence, which indicates the informativeness of a response. Acc measures the knowledge selection accuracy. As we skip the knowledge selection step, we select knowledge by matching the generated response with each knowledge candidate in WoW using the F1 score. Entity-F1 (or EF1 for short) identifies entities in text using Spacy, deletes the non-entity words, and calculates the F1 score between the modified generated text and the ground- truth response. EF1 eliminates the impact of the stop-word and focuses on the accuracy of entities.
这些评价指标的实现参考:https://github.com/sunnweiwei/MixCL/blob/main/utils/evaluation.py
另外邀请新的三个标注人员对测试样本中的100个样本进行标注,从四个方面进行打分。
Informativeness(0、1、2分), which measures whether the response is knowledge-inclusive; Relevancy(0、1、2分), which measures whether the response’s content is relevant to the dialogue; Factuality(0或1分), which measures whether the information in the response is factually correct; and Humanlikeness(0、1、2分), which measures whether the response is human-like in its fluency and naturalness.
backbone选择BART-Large(400M),知识库则为Wikipedia
(1)自动评估
可知在各种指标上效果都是提升比较明显的。
(2)人工评估
克制提出的MixCL在人工打分上也是最高的,部分指标也逼近人类回复的打分。
(3)消融实验
模型训练使用了三个loss和两个采样函数。发现如果缺少使用一个部分呢,效果都会下降。但是指标上下降也并不明显
(4)有效性验证
横轴表示模型生成结果的等待时间,纵轴为F1值。
可知我们的方法用最少的latency(等待时间)获得了最佳的F1值,说明整体性能是很优的。
(5)Case Study