ICLR2022 | GREASE LM: 图推理增强QA上的LM

本文是Christopher D. Manning和Jure Leskovec合作的一篇文章,NLPer和GNNer应该对这两个名字不陌生,一个讲了CS224N,一个讲了CS224W。

本文提出一种用图推理增强QA的LM的架构:GREASE LM。GREASE LM通过对LM和GNN进行多层次深度交互,有效捕捉导GNN的结构信息与LM的语义约束,从而提升了模型在QA任务上的性能,和处理复杂问题的能力。

研究背景

回答文本相关的问题需要从陈述的上下文(context)和它背后的知识(knowledge)来推理。

QA任务的基本范式是pre-trained LM。尽管在普通的benchmark上表现良好,但是当用于测试的example的分布与精调的数据集不相同时,模型的表现会变的挣扎。因为这种学习方式希望通过简单(偶尔是错误)的模式匹配直接走捷径获得答案,虽然能从LM中捕捉到context的情景约束和细微差别,但不能很好的表达概念之间的潜在关系。我们希望能用context提供的显性信息和内在的隐性知识来进行健壮的、结构化的推理。

Knowledge Graph(KG)由描述实体关系的三元组构成,被认为蕴含有大量的知识,在QA之类的推理任务上作用显著,因此用KG给LM扩充知识的结构化表示成为热门问题。然而,要将KG的推理优势扩展到一般的QA(问题和答案用自然语言表示,不容易转换成严格的逻辑查询),需要根据QA example提供的信息和约束找到KG中正确的知识集成。

早期LM+KG的工作通常将两种模态在浅层并且以非交互的方式融合,比如各自独立编码最后预测时再融合,或者用一个模态去加强另一个模态的输入。一般分为三种:

  1. 双塔结构,没有交互;
  2. KG支撑LM,比如用KG的编码来增强QA example的文本表示;
  3. 第三种是LM支撑KG,用文本表示(LM的最后一层)去增强从example提取的KG。

这三种结构信息流动最多有一个方向,这样两种模态信息交互的能力有限。如果想更好的模拟结构化的情景推理,需要实现双向的信息交互,让KG获得LM的细微语义约束,KG和LM进行深层次的融合。最近的一些工作探讨了这两种模态更深层次的集成。有的通过在结构化的KG数据上训练LM来将隐层知识编码进LM,然后用LM去生成针对QA的小KG。但是这样KG在转化成文本之后结构信息就丢掉了。QA-GNN用消息传递来联合更新LM和GNN,但是他们将LM池化作为整个架构的文本组件,限制了文本表示的更新能力。还有一些工作在预训练的时候将KG融入到LM中,但是模态的交互方向更多是将知识送到语言中。 Sun的工作与GREASE LM相似,但是它交互的bottleneck需要高精度的实体链接;并且LM和KG的参数共享限制了表达能力。

模型架构

符号定义:
MCQA(multiple choice question answering)的数据集包括上下文段落 c c c ,问题 q q q ,答案候选集 A A A ,都以文本的形式表示。本工作中,还假设有额外的知识图谱 G G G 提供背景知识。

给定QA example ( c , q , A ) (c,q,A) (c,q,A) 和KG G G G ,判断正确的答案 a ∈ A a \in A aA

我们将自然语言中的token序列表示为 { w 1 , … , w T } \{w_1,\dots,w_T \} {w1,,wT} ,其中 T T T 表示token数量,token w t w_t wt l l l -th layer的表示为 h t ( l ) h_t^{(l)} ht(l) 。KG的点集表示为 { e 1 , … , e J } \{e_1,\dots,e_J\} {e1,,eJ} ,其中 J J J 表示节点的数量。节点 e j e_j ej l l l -th layer的表示为 e j ( l ) e_j^{(l)} ej(l)

输入表示
先将 c , q , a c,q,a c,q,a 和分离符并起来作为模型输入 [ c ; q ; a ] [c;q;a] [c;q;a] ,转换成token序列 { w 1 , … , w T } \{w_1,\dots,w_T \} {w1,,wT} ;然后用输入序列去检索(retrieval)出 G G G 的子图 G s u b G_{sub} Gsub , G s u b G_{sub} Gsub 提供跟QA example相关的知识. G s u b G_{sub} Gsub 的点集表示为 { e 1 , … , e J } \{e_1,\dots,e_J\} {e1,,eJ} .

KG Retrieval
首先根据文本中的实体从 G G G 中链接出一个初始点集 V l i n k e d V_{linked} Vlinked 。然后将 V l i n k e d V_{linked} Vlinked 中任意点对之间的2-hop路径(长度为2,也就是中间只有一个点,也就是桥点)的桥点加进去形成 V r e t r i e v e d V_{retrieved} Vretrieved 。然后再对 V r e t r i e v e d V_{retrieved} Vretrieved 里的点计算相关分数(relevance score) : 将node name并在QA example后面,通过LM得到node name的output score,作为relavance score。我们取 V r e t r i e v e d V_{retrieved} Vretrieved 中分数最高的200个点为 V s u b V_{sub} Vsub ,剩下的都扔掉。最后,将所有链接两个 V s u b V_{sub} Vsub 中的点的边加进去形成 G s u b G_{sub} Gsub 。另外, G s u b G_{sub} Gsub 里的每个点都做一个标记,标记这个点对应的实体是从哪里来的,来自上下文 c c c / 询问 q q q / 答案 a a a / 这些点的相邻节点。本文之后的KG都是表示 G s u b G_{sub} Gsub .


图1. GREASELM模型架构图

GREASE LM 整体架构有两个堆叠组件:

  1. 单模态的LM层*N:获得输入token的初始表示
  2. 交叉模态的GREASELM层*M:将LM层的文本表示与KG的图表示融合在一起

Language Pre-Encoding

{ w 1 , … , w T } \{w_1,\dots,w_T \} {w1,,wT} 的token、段、位置嵌入求和作为 l = 0 l=0 l=0 时的表示 { h i n t ( 0 ) , h 1 ( 0 ) , … , h T ( 0 ) } \{h_{int}^{(0)},h_1^{(0)},\dots,h_T^{(0)}\} {hint(0),h1(0),,hT(0)} 。之后就用LM-layer运算出每一层的表示。LM-layer的参数初始为预训练的结果。

{ h i n t ( l ) , h 1 ( l ) , … , h T ( l ) } = L M − l a y e r ( { h i n t ( l − 1 ) , h 1 ( l − 1 ) , … , h T ( l − 1 ) } ) f o r    l = 1 , … , N \{h_{int}^{(l)},h_1^{(l)},\dots,h_T^{(l)}\}=LM-layer(\{h_{int}^{(l-1)},h_1^{(l-1)},\dots,h_T^{(l-1)}\}) \\ for \ \ l=1,\dots,N {hint(l),h1(l),,hT(l)}=LMlayer({hint(l1),h1(l1),,hT(l1)})for  l=1,,N

GreaseLM layer

Interaction Bottlenecks:

首先定义用于交互的 interaction token w i n t w_{int} wintinteraction node e i n t e_{int} eint ,作为两个模态交互的bottlenecks。将 w i n t w_{int} wint 添加到token序列里面,将 e i n t e_{int} eint 链接 G s u b G_{sub} Gsub 中点集 V l i n k V_{link} Vlink 。(不是 G s u b G_{sub} Gsub 所有点)

GreaseLM layer有三个组成部分:

  1. transformer LM encoder block
  2. GNN layer
  3. MInt layer

Language Representation

在第 l l l 层GreaseLM layer,将token embeddings { h i n t ( N + l − 1 ) , h 1 ( N + l − 1 ) , … , h T ( N + l − 1 ) } \{h_{int}^{(N+l-1)},h_1^{(N+l-1)},\dots,h_T^{(N+l-1)}\} {hint(N+l1),h1(N+l1),,hT(N+l1)} 输入到transformer LM encoder block继续编码:
{ h ~ i n t ( N + l ) , h ~ 1 ( N + l ) , … , h ~ T ( N + l ) } = L M − L a y e r ( { h i n t ( N + l − 1 ) , h 1 ( N + l − 1 ) , … , h T ( N + l − 1 ) } ) f o r    l = 1 , … , M \{\widetilde{h}_{int}^{(N+l)},\widetilde{h}_1^{(N+l)},\dots,\widetilde{h}_T^{(N+l)}\}=LM-Layer(\{h_{int}^{(N+l-1)},h_1^{(N+l-1)},\dots,h_T^{(N+l-1)}\})\\ for \ \ l=1,\dots,M {h int(N+l),h 1(N+l),,h T(N+l)}=LMLayer({hint(N+l1),h1(N+l1),,hT(N+l1)})for  l=1,,M

h ~ \widetilde{h} h 表示融合前的embeddings。

之后用于交互的bottleneck h i n t ( N + l ) h_{int}^{(N+l)} hint(N+l) 经过MInt会得到GNN的信息,那么在下一层的transformer LM encoder block的时候, h i n t ( N + l ) h_{int}^{(N+l)} hint(N+l) 会把GNN的信息传递给 h ( N + l + 1 ) h^{(N+l+1)} h(N+l+1)

Graph Representation

G s u b G_{sub} Gsub 中node embedding用MHGRN初始化:使用预定义的模板将KG中的知识三元组转换为句子。然后将句子送到BERT-Large LM中计算嵌入。最后,对于所有包含实体的句子,我们提取这些句子中实体的符号表示,在这些表示上进行均值池化并投影。

经过初始化,得到 { e 1 ( 0 ) , … , e J ( 0 ) } \{e_1^{(0)},\dots,e_J^{(0)}\} {e1(0),,eJ(0)} ,并随机化初始bottleneck e i n t ( 0 ) e_{int}^{(0)} eint(0) 的embedding。
在每一层GNN,做一次消息传递。

{ e ~ i n t ( l ) , e ~ 1 ( l ) , … , e ~ J ( l ) } = G N N ( { e i n t ( l − 1 ) , e 1 ( l − 1 ) , … , e J ( l − 1 ) } ) f o r    l = 1 , … , M \{\widetilde{e}_{int}^{(l)},\widetilde{e}_1^{(l)},\dots,\widetilde{e}_J^{(l)}\}=GNN(\{e_{int}^{(l-1)},e_1^{(l-1)},\dots,e_J^{(l-1)}\})\\ for \ \ l=1,\dots,M {e int(l),e 1(l),,e J(l)}=GNN({eint(l1),e1(l1),,eJ(l1)})for  l=1,,M

具体的更新方式是GAT的一种变种,每个node根据邻居做消息传递更新表示。

e ~ j ( l ) = f n ( ∑ e s ∈ N e j ∪ { e j } α s j m s j ) + e j ( l − 1 ) \widetilde{e}^{(l)}_j=f_n(\sum_{e_s \in N_{e_{j}} \cup \{e_j\}}\alpha_{sj}m_{sj})+e_j^{(l-1)} e j(l)=fn(esNej{ej}αsjmsj)+ej(l1)

N e j N_{e_{j}} Nej 表示 e j e_j ej 的邻域, m s j m_{sj} msj 表示邻点 e s e_s es 传递给 e j e_j ej 的信息, α s j \alpha_{sj} αsj 是用来缩放 m s j m_{sj} msj 的注意力权重, f n f_n fn 是一个两层的MLP。
m s j m_{sj} msj 具体计算方式如下:

r s j = f r ( r ~ s j , u s , u j ) m s h = f m ( e s ( l − 1 ) , u s , r s j ) r_{sj}=f_r(\widetilde{r}_{sj},u_s,u_j) \\ m_{sh}=f_m(e_s^{(l-1)},u_s,r_{sj}) \\ rsj=fr(r sj,us,uj)msh=fm(es(l1),us,rsj)

u s u_s us u j u_j uj 是node type embedding(KG Retrieval最后加的类型标记), r ~ s j \widetilde{r}_{sj} r sj e s e_s es e j e_j ej 的relation embedding, f r f_r fr 是一个两层的MLP, f m f_m fm 是一个线性变换。

α s j \alpha_{sj} αsj 具体计算方式如下:

q s = f q ( e s ( l − 1 ) , u s ) k j = f k ( e j ( l − 1 ) , u j , r s j ) γ s j = q s T k j D α s j = e x p ( γ s j ) ∑ e s ∈ N e j ∪ { e j } e x p ( γ s j ) q_s=f_q(e_s^{(l-1)},u_s) \\ k_j=f_k(e_j^{(l-1)},u_j,r_{sj}) \\ \gamma_{sj}=\frac{q_s^Tk_j}{\sqrt{D}} \\ \alpha_{sj}=\frac{exp(\gamma_{sj})}{\sum_{e_s \in N_{e_{j}} \cup \{e_j\}}exp(\gamma_{sj})} qs=fq(es(l1),us)kj=fk(ej(l1),uj,rsj)γsj=D qsTkjαsj=esNej{ej}exp(γsj)exp(γsj)

f q f_q fq f k f_k fk 都是线性变化。

同理, e i n t e_{int} eint 在获得LM传递过来的信息之后,在下一层的GNN中,会将信息传递给其他的node。

Modality Interaction

在通过LM layer和GNN layer更新过各自的embedding之后,用 modality interaction layer (MInt) 来让两个模态的信息通过 token w i n t w_{int} wint 和 node e i n t e_{int} eint 这两个bottleneck进行融合。作者直接将 h ~ i n t ( l ) \widetilde{h}_{int}^{(l)} h int(l) e ~ i n t ( l ) \widetilde{e}_{int}^{(l)} e int(l)并起来,作为输入通过 MInt 之后,再将混合后的输出分成 h i n t ( l ) h_{int}^{(l)} hint(l) e i n t ( l ) e_{int}^{(l)} eint(l)

[ h i n t ( l ) ; e i n t ( l ) ] = M I n t ( [ h ~ i n t ( l ) ; e ~ i n t ( l ) ] ) [h_{int}^{(l)};e_{int}^{(l)}]=MInt([\widetilde{h}_{int}^{(l)};\widetilde{e}_{int}^{(l)}]) [hint(l);eint(l)]=MInt([h int(l);e int(l)])

MInt为一个两层的MLP,但是也可以用别的融合操作来替换。除了用于交互的 w i n t w_{int} wint e i n t e_{int} eint,其他embedding都保持不变: w ( l ) = w ~ ( l )   f o r   w ∈ { w 1 , … , w T } ,   e ( l ) = e ~ ( l )   f o r   e ∈ { e 1 , … , e J } w^{(l)}=\widetilde{w}^{(l)}\ for \ w \in \{w_1,\dots,w_T\},\ e^{(l)}=\widetilde{e}^{(l)}\ for \ e \in \{e_1,\dots,e_J\} w(l)=w (l) for w{w1,,wT}, e(l)=e (l) for e{e1,,eJ} h i n t ( l ) h_{int}^{(l)} hint(l) e i n t ( l ) e_{int}^{(l)} eint(l) 会在下一层自身模态交互时由传递给这些点。

Learning & Inference

对于MCQA任务,给定问题 q q q ,从候选集 A A A 中选择一个答案 a a a a a a 正确的概率为 p ( a ∣ q , c ) ∝ e x p ( M L P ( h i n t N + M , e i n t M , g ) ) p(a|q,c) \propto exp(MLP(h_{int}^{N+M},e_int^{M},g)) p(aq,c)exp(MLP(hintN+M,eintM,g)) g g g 为将 h i n t N + M h_{int}^{N+M} hintN+M 作为query、对 { e j M ∣ e j ∈ e 1 , … , e J } \{e_j^{M}|e_j \in {e_1,\dots,e_J}\} {ejMeje1,,eJ} 的基于注意力的池化。采用交叉熵作为loss,选择 a r g   m a x a ∈ A   p ( a ∣ q , c ) arg \ max_{a \in A} \ p(a|q,c) arg maxaA p(aq,c) 作为最合理的答案。

实验结果

MCQA数据集

数据集 内容 LM KG GREASE LM相比于LM的性能提升 GREASE LM相比于LM+KG的性能提升
CommonsenseQA 常识 RoBERTa-Large ConceptNet 5.5% 0.9%
OpenbookQA 基本的科学知识 AristoRoBERTa ConceptNet 6.6% 1.8%
MedQA-USMLE 生物医学和临床知识 SapBERT 自建的知识图谱+DrugBank 1.3% 0.5%

表格中LM和KG表示Grease LM采用的LM和KG

Dataset Result



表1. 数据集示例


表2. CommonsenseQA Result

ICLR2022 | GREASE LM: 图推理增强QA上的LM_第1张图片
表3. OpenbookQA Result

表2和表3中Grease LM都要比QA-GNN优秀,说明这样持续的融合比不持续融合的性能更强。
ICLR2022 | GREASE LM: 图推理增强QA上的LM_第2张图片
表4. 与大模型在OpenBookQA上比较

在表5中,Grease LM实现了第3高,相比于参数接近的模型,性能是最高的。

定量分析

作者希望知道模型在更复杂的推理上的表现,但是没有一个明确的方法取衡量命题的推理复杂性。于是作者用3个特性来表述:介词短语的数量(视为显性约束的数量,虽然有时选择正确的answer的过程中会用不上这种约束);否定词(e.g.,no,never)的出现;模糊词(e.g.,sometimes,maybe)的出现。

ICLR2022 | GREASE LM: 图推理增强QA上的LM_第3张图片
表5. 复杂推理的表现

如表5所示,在否定项和模糊项上,Grease LM都显著优于RoBERTa-Large和QA-GNN,说明Grease LM对于细微语意约束捕捉的更好。没有介词短语的时候,QA-GNN强于Grease LM;但是当问题复杂度的上升——介词短语逐渐增加后,Grease LM的表现会好于QA-GNN。QA-GNN的融合方式是将LM对于context的最终表示初始化为GNN的一个node,这种末端融合在一定程度上有效提高了性能,但是这样会在LM与KG交互之前,将整个context压缩成一个向量,严重限制了能被捕捉到的交互信息。

另外一个发现是即便没有介词短语,GREASELM和QA-GNN都比RoBERTa-Large好,可能是因为这些问题不需要推理,但是需要一些特定的常识,这些常识在RoBERTa-Large预训练的时候可能没有学到。

定性分析


图2. 图注意力权重的变化

在图2中,作者检验了Grease LM和QA-GNN各自GNN中node之间的注意力权重,来分析Grease LM的推理步骤是否比QA-GNN更有效。对于从 CommonsenseQA IH-dev拿出的这个例子,GreaseLM做出了正确预测:airplane,而QA-GNN的预测:motor vehicle 是错误的。

对于GreaseLM,从第一层到中间层,“bug”的权重逐渐增加,但是从中间层到最后层,权重下降了,符合“unlikey”的直觉。与此同时,“windshield”的权重从始至终都在增加。凭借着“windshield”与“airplane”之间的链接,“bug”与“car”的负链接,选择了正确的答案。

对于QA-GNN,“bug”的权重始终都在增加,可能是因为“bug”在context中反复出现,转化成GNN的node之后被很好的表示,但是没有像GreaseLM那样被LM重新表述。

泛化性

以上说明了GreaseLM在一般常识推理领域的表现,下面用MedQA-USMLE来评估泛化性。


表6. MedQA-USMLE Result

可以看出GreaseLM要比SapBERT,QA-GNN都要好。说明GreaseLM是一种对于多个领域/KG都适用的LM增广。


表9和表10. LM泛化性结果

为了评估GreaseLM的提升是否与用于使用的LM无关,作者用RoBERTA-BASE在CommonsenseQA上替换了RoBERTA-LARGE,用BioBERT和PubmedBERT在MedQA-USMLE上替换SapBERT。结果表明,将GreaseLM作为KG和LM的模态交互,可以改进多个LM的性能。

消融实验

ICLR2022 | GREASE LM: 图推理增强QA上的LM_第4张图片
表8. 消融实验结果

首先,当没有模态融合时,正确率从78.5%掉到76.5%,相当于QA-GNN。隔层融合也会降低性能,可能是因为影响了学习的连续性,当我们用预训练的LM权重来初始化模型的时候,会产生这样的特性。并且共享MInt的参数比不共享要好,可能因为在数据集不大的时候共享参数避免了过拟合。

对于GreaseLM的层数,当M=5时性能最好,但是M=4或者M=6的时候效果也差不多,说明模型对这个参数不敏感。

对于图的连接性,将 e i n t e_{int} eint 链接到所有的节点相比于只链接到 V l i n k V_{link} Vlink 会产生性能的下降,可能是因为是因为整个子图有200个点,全链接会导致过载。只连接到输入中的实体节点时,这些实体节点可以作为一个过滤器过滤掉不重要的信息。

对于KG node embedding的初始化,用随机权重会导致性能直接从78.5%降到60.8%,用标准的KG embedding(TranE)性能会恢复到77.7%。BERT-based始终是最好的。

你可能感兴趣的:(厚积薄发,深度学习,自然语言处理,语言模型,知识图谱)