[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction

Motivation

论文地址:URL
双塔模型缺乏文本对之间的交互,性能较差。有一些工作提出了后交互策略,包括MLP layers,cross-attention layers,Transformer layers等。然而,这些交互模块是在编码器之后添加的,而编码器编码过程中的交互仍然被忽略,与基于交互的模型相比,有很大的性能差距。

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第1张图片

Methodology

Preliminaries

Interaction-based Models

给定两个输入文本序列 X X X Y Y Y,该方法将它们拼接到一起变成 [ X ; Y ] [X;Y] [X;Y],然后使用Transformer encoder对它们进行编码。Transformer每一层包含了两个子层:多头自注意力操作(MHA)和前馈网络(FFN):
M ( l ) = s o f t m a x ( A t t ( Q ( l ) , K ( l ) ) ) H ^ ( l ) = L N ( M ( l ) V ( l ) + H ( l − 1 ) ) H ( l ) = L N ( F N N ( H ^ ( l ) ) + H ^ ( l − 1 ) ) M^{(l)} = softmax(Att(Q^{(l)},K^{(l)})) \\ \hat{H}^{(l)}=LN(M^{(l)}V^{(l)}+H^{(l-1)})\\ H^{(l)} = LN(FNN(\hat{H}^{(l)})+\hat{H}^{(l-1)}) M(l)=softmax(Att(Q(l),K(l)))H^(l)=LN(M(l)V(l)+H(l1))H(l)=LN(FNN(H^(l))+H^(l1))
其中 A t t ( Q , K ) = Q K T d Att(Q,K)=\frac{QK^T}{\sqrt{d}} Att(Q,K)=d QKT

Representation-based Models

表征模型通过两个独立的Siamese Transformer encoder对 X 和 Y 分别编码: H ~ x L = E n c x ( X ) , H ~ y L = E n c y ( y ) \widetilde{H}^L_x = Enc_x(X), \widetilde{H}^L_y = Enc_y(y) H xL=Encx(X),H yL=Ency(y)

VIRT

表征模型的主要弱点是在单独对两个输入序列进行编码时缺乏交互。

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第2张图片

MHA Analysis

交互模型的MHA操作如图2(b)中蓝色注意力图所示。该模型第l层的输入H可以拆分为X和Y两个部分, H = [ H x ; H y ] H=[H_x;H_y] H=[Hx;Hy].在attention图的计算中,Q,K矩阵同样可以拆分为两部分, Q = [ Q x ; Q y ] , K = [ K x ; K y ] Q=[Q_x;Q_y],K=[K_x;K_y] Q=[Qx;Qy],K=[Kx;Ky]。这样,最终的attention分数(S)就可以被分解为以下分块矩阵:
S = A t t ( [ Q x ; Q y ] , [ K x ; K y ] ) = [ A t t ( Q x , K x ) A t t ( Q x , K y ) A t t ( Q y , K x ) A t t ( Q y , K y ) ] = [ S x → x S x → y S y → x S y → y ] \begin{aligned} S &= Att([Q_x;Q_y],[K_x;K_y])\\ &=\begin{bmatrix} Att(Q_x,K_x) & Att(Q_x,K_y) \\ Att(Q_y,K_x) & Att(Q_y,K_y) \end{bmatrix} \\ &= \begin{bmatrix} S_{x\rightarrow x} & S_{x\rightarrow y} \\ S_{y\rightarrow x} & S_{y\rightarrow y} \end{bmatrix} \\ \end{aligned} S=Att([Qx;Qy],[Kx;Ky])=[Att(Qx,Kx)Att(Qy,Kx)Att(Qx,Ky)Att(Qy,Ky)]=[SxxSyxSxySyy]
其中, S x → x S_{x\rightarrow x} Sxx S y → y S_{y\rightarrow y} Syy 是只在X或Y中进行了MHA。 S x → y S_{x\rightarrow y} Sxy S y → x S_{y\rightarrow x} Syx 表示交互模型中X和Y之间的交互,它们能够使用交互信息来丰富token的表示。然而,这一部分在表征模型中是缺失的,如图2(b)所示。

Interactive Knowledge Transfer

为了弥补缺失的交互,我们让表征模型模拟交互:
M ~ x → y = s o f t m a x ( A t t ( Q ~ x , K ~ y ) ) M ~ y → x = s o f t m a x ( A t t ( Q ~ y , K ~ x ) ) \widetilde{M}_{x\rightarrow y} = softmax(Att(\widetilde{Q}_x,\widetilde{K}_y))\\ \widetilde{M}_{y\rightarrow x} = softmax(Att(\widetilde{Q}_y,\widetilde{K}_x)) M xy=softmax(Att(Q x,K y))M yx=softmax(Att(Q y,K x))
这两个额外的注意图代表了表征模型中缺失的交互信号,然而,它们不能直接从表征模型中的编码器中计算出来,从而导致产生了不够有效的文本表示。为了缩小表征模型和交互模型之间的性能差距,我们建议将表征模型中缺失的注意力图与交互模型中已经存在的对应图对齐。通过这种方式,我们提炼出交互中的知识,并将其转移到双重编码器中,而没有任何额外的推理计算成本。这就是为什么我们把这种机制称为 “虚拟交互”。

我们使用一个训练好的交互模型作为老师,将它的知识蒸馏给学生,即表征模型。在每一层中,我们从交互模型中获得了注意力图 M x → y M_{x\rightarrow y} Mxy M y → x M_{y\rightarrow x} Myx,并将这些有监督的交互知识用于指导表征模型的学习。从形式上看,我们的目标是使所有层中 M ~ x → y \widetilde{M}_{x\rightarrow y} M xy M ~ y → x \widetilde{M}_{y\rightarrow x} M yx M x → y M_{x\rightarrow y} Mxy M y → x M_{y\rightarrow x} Myx之间的L2距离最小:
L v i r t = 1 2 L ∑ l = 1 L ( 1 m ∥ M ~ x → y ( l ) − M x → y ( l ) ∥ 2 + 1 n ∥ M ~ y → x ( l ) − M y → x ( l ) ∥ 2 ) \mathcal{L}_{virt} = \frac{1}{2L}\sum_{l=1}^L(\frac{1}{m}\Vert\widetilde{M}^{(l)}_{x\rightarrow y}-M^{(l)}_{x\rightarrow y} \Vert_2 +\frac{1}{n}\Vert\widetilde{M}^{(l)}_{y\rightarrow x}-M^{(l)}_{y\rightarrow x} \Vert_2) Lvirt=2L1l=1L(m1M xy(l)Mxy(l)2+n1M yx(l)Myx(l)2)
上述蒸馏只在训练阶段应用,以学习更好的双编码器。这保留了表征模型的特性,而没有额外的推理成本。

VIRT-Adapted Interaction

通过VIRT,交互式知识可以被汇入到表征模型的每个编码层中。然而,在Siamese编码之后,最后一层的表征仍然无法看到对方,因此缺乏明确的互动。为了充分利用学到的交互知识,我们进一步设计了一个与VIRT相适应的交互策略,在VIRT学到的注意图的指导下,将 H ~ x L \widetilde{H}^L_x H xL H ~ y L \widetilde{H}^L_y H yL 融合起来。

具体地,我们在 H ~ x L \widetilde{H}^L_x H xL H ~ y L \widetilde{H}^L_y H yL 之间进行VIRT-Adapted的交互:
M ^ x → y ( L ) = s o f t m a x ( A t t ( H ~ x ( L ) , H ~ y ( L ) ) ) M ^ y → x ( L ) = s o f t m a x ( A t t ( H ~ y ( L ) , H ~ x ( L ) ) ) u = P o o l ( M ^ x → y ( L ) H ~ y ( L ) ) v = P o o l ( M ^ y → x ( L ) H ~ x ( L ) ) \hat{M}^{(L)}_{x\rightarrow y} = softmax(Att(\widetilde{H}^{(L)}_x,\widetilde{H}^{(L)}_y))\\ \hat{M}^{(L)}_{y\rightarrow x} = softmax(Att(\widetilde{H}^{(L)}_y,\widetilde{H}^{(L)}_x))\\ u = Pool(\hat{M}^{(L)}_{x\rightarrow y} \tilde{H}^{(L)}_{y})\\ v = Pool(\hat{M}^{(L)}_{y\rightarrow x} \tilde{H}^{(L)}_{x}) M^xy(L)=softmax(Att(H x(L),H y(L)))M^yx(L)=softmax(Att(H y(L),H x(L)))u=Pool(M^xy(L)H~y(L))v=Pool(M^yx(L)H~x(L))
其中,Pool是平均池化。最终,我们进行简单的融合来做预测:
r = ( u , v , u − v , m a x ( u , v ) ) y = s o f t m a x ( M L P ( M L P ( r ) + r ) ) r = (u,v,u-v,max(u,v))\\ y = softmax(MLP(MLP(r)+r)) r=(u,v,uv,max(u,v))y=softmax(MLP(MLP(r)+r))
整体的训练损失是任务特定的有监督损失和蒸馏损失的结合,如下:
L = L t a s k + α L v i r t \mathcal{L} = \mathcal{L}_{task}+\alpha\mathcal{L}_{virt} L=Ltask+αLvirt
α \alpha α 是超参,本文中设为1。

Experiments

Datasets

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第3张图片

Experimental Setup

双塔两个塔参数共享。pooling策略选择平均pooling(而不是[CLS]),平均pooling比[CLS]效果更好。

Results

Main Results

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第4张图片

Ablation Study

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第5张图片

对于MNLI和RTE任务来说,去除 adapted interaction 引起的性能下降更为严重。我们的假设是,MNLI和RTE是自然语言推理任务,需要更精细的匹配信号,并且严重依赖显式交互。

Layer Importance

我们将VIRT应用于双编码器中不同的选定层,以了解不同编码器层中交互知识的重要性。

  1. VIRT-Last:只对最后k层应用VIRT

  2. VIRT-First:只对前k层应用VIRT

  3. VIRT-Skip:将VIRT应用于每k层中的第1层

  4. VIRT-All:对所有层应用VIRT

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第6张图片

我们观察到,当激活6层时,VIRT-First比VIRT-Last和VIRT-Skip表现得更好,这表明来自底层的交互知识起到了关键作用。(Wang et al., 2020) 表明蒸馏最后一层的信息就足够了。然而,当我们只在最后一层应用了VIRT时,我们发现,当教师模型和学生模型是异质的,仅仅蒸馏最后一层的信息会面临很大的性能下降。

Impact of VIRT Distillation

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第7张图片

Different Model Configurations

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第8张图片

Impact of α \alpha α

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第9张图片

Case Study

与没有VIRT的表征模型相比,有VIRT蒸馏的表征模型的注意力矩阵与基于交互的模型更加一致。

[EMNLP 2022] VIRT: Improving Representation-based Text Matching via Virtual Interaction_第10张图片

你可能感兴趣的:(文本匹配,双塔模型,Paper,notes,机器学习,人工智能)