现有的工作:
(1)没有关注话语之间的医疗实体关系。
(2)不能充分利用关键信息。
为了对患者的状况有一个全面的了解,医疗对话往往相对较长,且有丰富的医疗专业术语分散在各个话语中,这些话语之间的医疗实体及其复杂关系对捕获对话历史的关键信息、进而诱导回复的生成具有关键作用。
现有的工作在利用不同的话语之间的复杂医疗关系上存在不足,导致现有的医疗对话生成模型无法从长对话历史中获取关键信息以产生准确而有信息量的回复,这主要由于在进行关键信息召回时,交叉注意力机制没有使用显示的监督信号来训练。
如图所示,在第1句的“里急后重”和第4句“肠炎”之间存在症状关系,由于忽略这一医疗关系,生成的回复中可能会缺少关键实体“结肠炎”。
对多个话语之间复杂的医疗关系进行建模,进而显式地指导解码器在回复生成过程中充分利用对话历史中的关键信息。
本文提出了关键信息召回模型(Pivotal Information Recalling,MedPIR),它由两部分组成:知识感知的对话图编码器和召回增强的生成器。知识感知的对话图编码器通过开发话语之间实体之间的关系来构建一个对话图。对话图通过图注意力生成的表征送入生成器中,因此,知识感知的对话图编码器可以从全局对话结构的角度方便生成器使用分布在多个语句中的关键医疗信息。
然后,召回增强的生成器在生成确切的回复之前,通过生成一个对话的摘要来强化对关键信息的使用。召回增强生成器的设计是为了首先显式地从长对话历史中生成关键信息。然后,将召回信息作为回复的前缀来提示,生成更加关注于关键信息的回复。通过这种方式,
(1)召回生成器强制使用交叉注意力机制,以充分利用来自具有召回信号的编码器的关键信息。
(2)此外,召回增强的生成器还通过解码器内部的自注意力机制加强了从对话历史中召回的关键信息与回复之间的交互。
使用BERT-GPT作为backbone,BERT作为编码器,GPT作为解码器。上下文编码器通过编码拼接起来的对话历史,获得上下文表征 H c t x H_{ctx} Hctx。然后通过检索外部知识,通过一个知识编码器获得知识表征: H k H_k Hk。
由于基本的对话模型仅将医疗对话历史视为话语序列,难以对不同话语之间的多种医疗因果关系进行建模,这些复杂关系隐含了诱导下一步反应的关键医疗信息。为了应对这个问题,本文提出了KDGE,构建一个对话知识图,用图注意力网络来编码图。
首先,把对话历史转换成图,每句话语被视为一个结点。结点之间有两种类型的边,一种是常规的时序边,另一种是知识感知边,这种边把零散的话语通过医疗关系连接起来。这些知识感知边将来自外部医学知识图谱的医学知识融入对话中,使模型能够表示话语之间的复杂医疗关系。
具体而言,我们首先从每句话中抽取医学实体,然后到外部知识图谱CMeKG中查找他们的关系。如果来自两个话语的医学实体之间存在某种关系,我们在两个话语之间添加一条知识感知的边。
构建好知识感知的对话图后,我们使用Relational Graph Attention Network (RGAT)来编码这些对话中的关键关系信息。Relational Graph Attention Networks(2019)。
话语的句子级表征:对于图 G G G中的每个结点 v i v_i vi,我们使用一个基于transformer的编码器来编码其对应的话语,然后计算其平均的单词表征来获取句子级的表征(实际上就是MeanPooling): h i h_i hi。
话语在对话结构图中的结点表征:然后这个句子级的表征和它对应的speaker表征(speaker embedding)拼接起来,形成其初始的结点嵌入: v i 0 v_i^0 vi0。
最后,使用RGAT来更新结点的表示:
( v 1 , . . . , v M ) = R G A T ( ( v 1 0 , . . . , v M 0 ) , G ) (v_1,...,v_M) = RGAT((v_1^0,...,v_M^0),G) (v1,...,vM)=RGAT((v10,...,vM0),G)
话语的召回得分:为了进行对话召回,我们将上下文编码 H c t x H_{ctx} Hctx作为初始的对话历史表征,然后定义一个召回得分 α v i \alpha_{v_i} αvi作为话语 X i X_i Xi在召回过程中的重要性度量:
α v i = σ ( ( W v q h c t x ) T ( W v k v i ) ) \alpha_{v_i} = \sigma((W_v^qh_{ctx})^T(W_v^kv_i)) αvi=σ((Wvqhctx)T(Wvkvi))
这里的 h c t x h_{ctx} hctx是全部拼接对话历史经过BERT编码后的 H c t x H_{ctx} Hctx经过mean-pooled得到的。
这里像点积注意力机制,上下文作为查询,单句话(结点)的图表征结果作为键,计算注意力得分。
话语的最终结构编码:那么,最终 X i X_i Xi的结构编码通过句子编码和图结点编码和召回得分加权得到:
h s t c , i = α v i ( h i + v i ) h_{stc,i} = \alpha_{v_i}(h_i+v_i) hstc,i=αvi(hi+vi)
对话历史的结构编码:最后把全部句子的结构表征拼接起来,就得到了对话的结构编码: H s t c H_{stc} Hstc。
总结一下,如何获取对话历史的结构编码?首先,对对话历史的每句话编码拿到句子级表征;然后对对话图中的结点进行更新表示,拿到结点级表征;随后,计算每句话的召回得分,对话历史上下文做Query,结点表征做Key,然后进行加权得到话语的结构表征。最后,每句话的结构编码拼接得到对话历史的结构编码。
在base模型中,生成模型首先进行自注意力,在每个时间步生成解码状态,然后利用上下文表征 H c t x H_{ctx} Hctx和知识表征 H k H_k Hk进行交叉注意力。这种模型在训练回复生成的过程中,通常难以建模长对话历史、关注其中的关键信息。
本文提出REG,在生成回复之前,显示地生成关键信息 R R R, R R R是一个包含了对话历史关键信息的简洁的摘要。生成 R R R之后,模型将生成后续的回复:
y t = R E G ( H c t x , H k , H s t c , [ R ; y y < t ] ) y_t = REG(H_{ctx},H_k,H_{stc},[R;y_{y
训练时, R R R由医疗预训练模型PCL-MedBERT自动生成,作为一个训练模型召回关键信息的监督信号。推理时,MedPIR首先产生召回信息,然后生成回复,即脱离了PCL-MedBERT模块的监督信号。这种方式有两个优点:
预生成的 R R R为生成器通过自注意力访问关键历史信息提供了捷径。
强化了交叉注意力机制来关注编码器提供的关键信息。(没懂,这里是因为 R R R通过自注意力生成的表征,在下一步通过交叉注意力和编码器交互?)
如图所示,生成器首先生成召回信息 R R R,然后是一个分隔符[RSEP],注意,我们使用知识编码 H k H_k Hk的平均池化来作为分隔符的嵌入,以此来驱动生成过程中的知识融合。
为聚合编码器的不同类型的信息,在自注意力(SA)和LN模块之后,引入一个聚合模块:Fusion(·),这是一个门控机制,连接上下文编码 H c t x H_{ctx} Hctx、结构编码 H s t c H_{stc} Hstc、知识编码 H k H_k Hk。
首先,SA和LN的输出 h S , t l h_{S,t}^l hS,tl作为Query,三种信息作为Key进行交叉注意力。然后,通过一个线性层和sigmoid获得一个门控得分,紧接着三种得分进行softmax归一化,最后进行加权求和:
最终生成时:
在进行信息召回和回复生成时,门控聚合网络动态控制上下文编码、结构编码、知识编码的信息流。结构编码为上下文编码提供了一个补充信息,促进了关键信息召回的生成。
由于医疗对话摘要语料库的缺乏,本文使用PCL-MedBERT进行抽取式摘要,选择和目标回复最相关的几个对话历史话语作为训练信号。PCL-MedBERT对每个话语和回复进行编码,然后计算他们的余弦相似度:
然后,选择相似度得分最高的K个话语,拼接起来,作为目标召回信息。这虽然是一个比较模糊的监督信号,但是抽取出的摘要话语通常包含生成一条信息性医疗回复需要的关键信息,如下图所示:
为了进一步促进模型在推理时生成符合条件的 R R R,通过引入一个监督召回得分的二元交叉熵损失,促进对关键话语的识别:
其中, r i ∈ { 0 , 1 } r_i \in \{0,1\} ri∈{0,1},表示是否属于召回的关键话语。召回得分越高,表示相应话语对应召回的重要性越高。
最终优化目标,召回信息和回复生成是两个独立的子损失:
MedDG的研究工作表明,预测回复中可能出现的医疗实体对生成有信息性的医疗实体很有帮助,因此,本文训练自己的知识检索模型来检索可能在回复中出现的医疗实体。
首先,以对话历史中出现的医疗实体为中心节点,在CMeKG中选择具有一跳关系的子图。然后,我们只检索子图中的实体。我们采用两个独立的PCL-MedBERT分别编码对话历史 X X X和和实体 E E E(包含几个token),得到 h X h_X hX和 h E h_E hE然后取[CLS]作为编码器输出。二者的内积,即为该实体的检索得分。令 E i + E_i^+ Ei+为目标回复中出现的正实体, { E j − } j = 1 n \{E^-_j\}^n_{j=1} {Ej−}j=1n为 n n n个没出现的负实体,优化如下损失来训练检索器:
我们检索对话历史中的top-20个实体,然后使用另一个PCL-MedBERT作为知识编码器,将检索到的实体根据检索得分排序,然后用[SEP]拼接成一个序列,送入到PCL-MedBERT编码得到 H k H_k Hk。