文章基于prompt的一致性学习给出了一种zero-shot task generalization(零样本泛化学习)的无监督方法。数值实验表明,文章提出的指令一致性学习方法只需在几个prompt、几十个样本上进行训练,就可以在NLI等NLP任务上追平SOTA水平。
文章整体架构如下
首先简单介绍下zero-shot task generalization(零样本泛化学习):给定输入 x ∈ X x\in \mathcal{X} x∈X,零样本泛化学习旨在学习一个预训练模型PLM预测出 y ∈ Y y\in \mathcal{Y} y∈Y,其中PLM未在数据集 X \mathcal{X} X上训练过。零样本泛化学习要求模型可以泛化出一个新的表达式 f : X → Y f: \mathcal{X} \to \mathcal{Y} f:X→Y,而非仅仅在数据集上具有泛化能力。
给定prompt r r r, r r r包含一个输入模板 r x r_x rx、输出模板 r y r_y ry以及待放入模板的元数据 x , y x, y x,y,我们可以得到prompt-based输入: r x ( x ) , r y ( y ) r_x(x), r_y(y) rx(x),ry(y)。基于prompt的学习方法一般用 p θ ( r y ( y ) ∣ r x ( x ) ) p_{\theta} (r_y(y)|r_x(x)) pθ(ry(y)∣rx(x))来计算输出的概率 q ( y ∣ x , r ) ) q(y|x, r)) q(y∣x,r)),其中 θ \theta θ表示模型的参数。本文重点关注NLP的分类任务,则可以通过如下公式计算输出的概率: q ( y ∣ x , r ) = p θ ( r y ( y ) ∣ r x ( x ) ) ∑ y ′ ∈ Y p θ ( r y ( y ′ ) ∣ r x ( x ) ) (1) q(y|x, r) = \frac{p_{\theta} (r_y(y)|r_x(x))}{\sum_{y'\in\mathcal{Y}} p_{\theta} (r_y(y')|r_x(x))}\tag{1} q(y∣x,r)=∑y′∈Ypθ(ry(y′)∣rx(x))pθ(ry(y)∣rx(x))(1)。
文章的方法需要无标注的数据集 { x 1 , … , x N } \{x_1, \dots, x_N\} {x1,…,xN}和 K K K个不同的prompt { ( r x 1 , r y 1 ) , … , ( r x K , r y K ) } \{(r_x^1, r_y^1), \dots, (r_x^K, r_y^K)\} {(rx1,ry1),…,(rxK,ryK)}。其中无标注的数据集可以来自任意NLP(分类)任务的训练数据集或测试数据集,也可以来自我们要测试的任务的数据集。prompt可直接采用Public Pool of Prompts(p3)数据集里的prompt。
传统的一致性训练会扰乱样本,使得扰乱后的样本和之前的样本得到的输出尽可能一致。本文希望学习prompt级别的一致性,即不同prompt在单个样本上的学习结构尽可能一致。这样做可以1) 概念非常简单 2)缓解PLM“输入不同prompt结果不一致”的问题。
损失函数定义如下 L = − E x ∈ p d ( x ) E r i , e r j ∈ p ( r ) E y ^ ∈ q ^ ( y ∣ x , r i ) log p θ ( r y j ( y ^ ) ∣ r x j ( x ) ) \mathcal{L} = -\mathbb{E}_{x\in p_d(x)} \mathbb{E}_{r^i, er^j\in p(r)} \mathbb{E}_{\hat{y} \in \hat{q}(y|x,r^i)} \log p_{\theta} (r_y^j(\hat{y})|r_x^j(x)) L=−Ex∈pd(x)Eri,erj∈p(r)Ey^∈q^(y∣x,ri)logpθ(ryj(y^)∣rxj(x))
, p d p_d pd是数据集的分布, p ( r ) p(r) p(r)表示 K K K个prompt的随机prompt对的均匀分布, q ^ \hat{q} q^定义为式(1)的条件分布。这里简单解释下,如图所示,给定prompt r i , r j r^i, r^j ri,rj,我们首先预测 y ^ ∈ q ^ ( y ∣ x , r i ) \hat{y}\in \hat{q}(y|x, r^i) y^∈q^(y∣x,ri),即当promt为 r i r^i ri时得到输出 y ^ \hat{y} y^。当prompt为 r j r^j rj时,我们希望最大化输出结果为 y ^ \hat{y} y^(即和 r i r^i ri输出相同)的概率 p θ ( r y j ( y ^ ) ∣ r x j ( x ) ) p_{\theta} (r_y^j(\hat{y})|r_x^j(x)) pθ(ryj(y^)∣rxj(x)),取负对数和期望之后,即得到上述损失函数。我们称上述训练方法为swarm distillation。
如果直接采用上述方法进行训练,则我们很容易collapse,得到一个平凡解:所有prompt、所有样本均输出同一个结果可以实现损失函数最小。另一方面,训练后的模型可以能忘记之前的知识,即castrophic forgetting。为了避免collapse和catastrophic forgetting,文章提出下述两种方法:
文章提出了一种基于prompt一致性的zero-shot task generation学习方法swarm distillation,且采用了LoRA和Fleiss’ Kappa方法避免学习灾难性遗忘或学习结果collapse。文章在多个NLP下游任务上进行了验证,发现swarm distillation在多个任务上表现超过SOTA。此外,数值实验表明,swarm distillation只需要4个prompt,10+个样本就可以对源模型(T0)进行提升。
但实验也表明,swarm distillation方法在增加到一定样本量之后性能就达到了饱和,当我们有很多标记样本可用的时候,性能可能不及监督微调。未来可以将swarm distillation与few-shot少样本学习或预训练相结合来实现在标记样本上的性能提升。
Prompt Consistency for Zero-Shot Task Generalization