最近刚好在看BAMnet这篇做KBQA的实验,顺带把记忆网络的几篇经典文章看了一下做一下总结。另外就是 Facebook 7月份刚上传的一篇用 memory 来改进 BERT 结构的文章 Large Memory Layers with Product Keys 在精度不降的前提下提高了 BERT 的效率,看来脸书是真的喜欢 Memory这个概念的运用拓展。
MEMORY NETWORKS ICLR2015
记忆网络的提出主要是解决 RNN 的隐层无法保存长期记忆、能传达的记忆内容太少的问题。这里就引入了一种外部记忆模块就和 RAM 一样能保存大量历史信息,与一些读写模块一起组成了 Memory Network,如下图:
由图中可以看出一共有四个模块:
I: input feature map
G: generalization
O: output
R: response
具体来说,I模块是 embedding lookup,将原始文本转化为词向量,G模块将输入的向量存储在memory数组的下一个位置,不做其他操作,对老的记忆不做修改。O模块根据输入的问题向量在所有的记忆中选择出 topk 相关的记忆,具体选择方式为,先选记忆中最相关的memory:
o 1 = O 1 ( x , m ) = arg max i = 1 , … , N s O ( x , m i ) o_{1}=O_{1}(x, \mathbf{m})=\underset{i=1, \ldots, N}{\arg \max } s_{O}\left(x, \mathbf{m}_{i}\right) o1=O1(x,m)=i=1,…,NargmaxsO(x,mi)
其中目标函数是用 Bilinear Regression 来建模问题 q q q 和记忆 m m m 的相关程度:
s ( x , y ) = Φ x ( x ) ⊤ U ⊤ U Φ y ( y ) s(x, y)=\Phi_{x}(x)^{\top} U^{\top} U \Phi_{y}(y) s(x,y)=Φx(x)⊤U⊤UΦy(y)
接下来根据选择出的 o 1 o_1 o1 和输入 x x x 一起选择与他们两个最相关的记忆 o 2 o_2 o2:
o 2 = O 2 ( x , m ) = arg max i = 1 , … , N s O ( [ x , m o 1 ] , m i ) o_{2}=O_{2}(x, \mathbf{m})=\underset{i=1, \ldots, N}{\arg \max } s_{O}\left(\left[x, \mathbf{m}_{o_{1}}\right], \mathbf{m}_{i}\right) o2=O2(x,m)=i=1,…,NargmaxsO([x,mo1],mi)
然后一直迭代下去,就这样选择出于Question最相关的 topk 个 memory slot 。将其作为R模块的输入,用于生成最终的答案。其实这里也很简单就是使用与上面相同的评分函数计算所有候选词与R输入的相关性,得分最高的词语就作为正确答案输出即可:
r = argmax w ∈ W s R ( [ x , m o 1 , m o 2 ] , w ) r=\operatorname{argmax}_{w \in W} s_{R}\left(\left[x, \mathbf{m}_{o_{1}}, \mathbf{m}_{o_{2}}\right], w\right) r=argmaxw∈WsR([x,mo1,mo2],w)
如果数据集是输出一句话的,那么最后可以用RNNLM来生成带有回复信息的句子。
最后它定义的损失函数(margin ranking loss)如下,设 k = 2 k=2 k=2 :
∑ f ‾ ≠ m o 1 max ( 0 , γ − s O ( x , m o 1 ) + s O ( x , f ‾ ) ) + ∑ f ‾ ′ ≠ m o 2 max ( 0 , γ − s O ( [ x , m o 1 ] , m o 2 ] ) + s O ( [ x , m o 1 ] , f ‾ ′ ] ) ) + ∑ r ⃗ ≠ r max ( 0 , γ − s R ( [ x , m o 1 , m o 2 ] , r ) + s R ( [ x , m o 1 , m o 2 ] , r ‾ ] ) ) \begin{array}{c}{\sum_{\overline{f} \neq \mathbf{m}_{o_{1}}} \max \left(0, \gamma-s_{O}\left(x, \mathbf{m}_{o_{1}}\right)+s_{O}(x, \overline{f})\right)+} \\ {\sum_{\overline{f}^{\prime} \neq \mathbf{m}_{o_{2}}} \max \left(0, \gamma-s_{O}\left(\left[x, \mathbf{m}_{o_{1}}\right], \mathbf{m}_{o_{2}}\right]\right)+s_{O}\left(\left[x, \mathbf{m}_{o_{1}}\right], \overline{f}^{\prime}\right] ) )+} \\ {\sum_{\vec{r} \neq r} \max \left(0, \gamma-s_{R}\left(\left[x, \mathbf{m}_{o_{1}}, \mathbf{m}_{o_{2}}\right], r\right)+s_{R}\left(\left[x, \mathbf{m}_{o_{1}}, \mathbf{m}_{o_{2}}\right], \overline{r}\right]\right) )}\end{array} ∑f=mo1max(0,γ−sO(x,mo1)+sO(x,f))+∑f′=mo2max(0,γ−sO([x,mo1],mo2])+sO([x,mo1],f′]))+∑r=rmax(0,γ−sR([x,mo1,mo2],r)+sR([x,mo1,mo2],r]))
第一个意思是有没有挑选出正确的第一句话,第二个意思是正确挑选出了第一句话后能不能正确挑出第二句话,合起来就是能不能挑选出正确的语境,用来训练 attention 参数,第三个式子把正确的 supporting fact 作为输入,能不能挑选出正确的答案,来训练 response 参数。
总结
End to End MEMORY NETWORKS NIPS2015
上文的损失函数可以看出O和R模块承担了主要的任务并且都需要监督,我们需要知道O选择的supporting fact 是否正确,R生成的 response 是否正确。这篇其实就是用了soft attention 来估计每一个 supporting fact 的相关程度,实现了端到端的 BP 过程。论文中提出了单层和多层两种架构,多层其实就是将单层网络进行stack。
单层结构如下:
其中A,B,C三个矩阵就是 embedding 矩阵,主要是将输入文本和 Question 编码成词向量,W是最终的输出矩阵。
1、输入模块的主要作用是将输入的文本转化成向量并保存在memory中,本文中的方法是将每句话压缩成一个向量对应到memory中的一个slot(上图中的蓝色或者黄色竖条)。其实就是根据一句话中各单词的词向量得到句向量。论文中提出了两种编码方式,BoW和位置编码。BoW就是直接将一个句子中所有单词的词向量求和表示成一个向量的形式,这种方法的缺点就是将丢失一句话中的词序关系,进而丢失语义信息;而位置编码的方法,不同位置的单词的权重是不一样的,然后对各个单词的词向量按照不同位置权重进行加权求和得到句子表示。位置编码公式如下:
l k j = ( 1 − j / J ) − ( k / d ) ( 1 − 2 j / J ) l_{k j}=(1-j / J)-(k / d)(1-2 j / J) lkj=(1−j/J)−(k/d)(1−2j/J)
m i = ∑ j l j ⋅ A x i j m_{i}=\sum_{j} l_{j} \cdot A x_{i j} mi=j∑lj⋅Axij
另外为了编码时序信息,我们需要在上面得到 m i m_i mi 的基础上再加上个矩阵对应每句话出现的顺序,不过这里是按反序进行索引。将该时序信息编码在 T a T_a Ta 和 T c T_c Tc 两个矩阵里面,所以最终每句话对应的记忆mi的表达式如下所示:
m i = ∑ j l j ⋅ A x i j + T A ( i ) m_{i}=\sum_{j} l_{j} \cdot A x_{i j}+T_{A}(i) mi=j∑lj⋅Axij+TA(i)
2、上面的 Input 模块可以将输入文本编码为向量的形式并保存在 memory 中,这里分为两个部分,一个用于跟 Question 相互作用得到各个 memory slot 与问题的相关程度,另一个则使用该信息产生输出。
3、输出模块根据Question产生了各个memory slot的加权求和,也就是记忆中有关Question的相关知识,Response模块主要是根据这些信息产生最终的答案。其结合o和q两个向量的和与W相乘在经过一个softmax函数产生各个单词是答案的概率,值最高的单词就是答案。并且使用交叉熵损失函数最为目标函数进行训练。
4、多层结构(K hops)也很简单,相当于做多次 addressing、 attention,每次 focus 不同的 memory 上,不过在第 k+1 次 attention 时 query 的表示需要把之前的 context vector 和 query 拼起来,其他过程几乎不变,也就是说上面几层的输入就是下层o和u的和 u k + 1 = u k + o k u^{k+1}=u^{k}+o^{k} uk+1=uk+ok 。最后在顶层输出时就是:
a ^ = Softmax ( W u K + 1 ) = Softmax ( W ( o K + u K ) ) \hat{a}=\operatorname{Softmax}\left(W u^{K+1}\right)=\operatorname{Softmax}\left(W\left(o^{K}+u^{K}\right)\right) a^=Softmax(WuK+1)=Softmax(W(oK+uK))
至于各层的参数选择,论文中提出了两种方法来减少参数量,而且如果每层参数都不同的话会导致参数很多难以训练:
总结
由上图的 3-hop 的实验结果可以看出这种记忆网络的推理效果还是有点成型了但是效果还是不理想,它只是简单的把context线性变换成了一个整体的memory,为了在对话中引入更多的外部知识,我们就引出了下面要说的 key-value MemNN。
Key-Value Memory Networks for Directly Reading Documents EMNLP2016
其实看到key-value我第一反应就是之前看的self-attention,看完论文个人感觉其实还是挺相近的,key做寻址value做后续的加权求和。这里的KV-MemNN将memory存入(key,value)键值对,并且引入了Wiki、KB、IE三种知识库,整体框架如下:
Key hashing:根据输入的问题从知识源中用倒排索引检索出与问题相关的facts存入memory,从而减小后续的进一步匹配数据量
Key addressing:利用hashing的得到的 candidate memories 去和 query 线性变换后的结果计算一个相关概率:
p h i = Softmax ( A Φ X ( x ) ⋅ A Φ K ( k h i ) ) p_{h_{i}}=\operatorname{Softmax}\left(A \Phi_{X}(x) \cdot A \Phi_{K}\left(k_{h_{i}}\right)\right) phi=Softmax(AΦX(x)⋅AΦK(khi))
Value Reading:得到相关概率后对 value 进行加权求和得到输出向量 o o o :
o = ∑ i p h i A Φ V ( v h i ) o=\sum_{i} p_{h_{i}} A \Phi_{V}\left(v_{h_{i}}\right) o=i∑phiAΦV(vhi)
这样就完成了一个hop的操作,接下来将输出向量o与输入问题的向量表示q相加 ,经过Ri矩阵进行映射,在作为下一层的输入 q 2 = R 1 ( q + o ) q_{2}=R_{1}(q+o) q2=R1(q+o),相关概率也随之更新 p h i = Softmax ( q j + 1 ⊤ A Φ K ( k h i ) ) p_{h_{i}}=\operatorname{Softmax}\left(q_{j+1}^{\top} A \Phi_{K}\left(k_{h_{i}}\right)\right) phi=Softmax(qj+1⊤AΦK(khi)) 重复循环这个过程即可。最后在答案预测如下,其中 B Φ Y ( y i ) B \Phi_{Y}\left(y_{i}\right) BΦY(yi) 是对 candidates 的向量表示。
a ^ = argmax i = 1 , … , C Softmax ( q H + 1 ⊤ B Φ Y ( y i ) ) \hat{a}=\operatorname{argmax}_{i=1, \ldots, C} \operatorname{Softmax}\left(q_{H+1}^{\top} B \Phi_{Y}\left(y_{i}\right)\right) a^=argmaxi=1,…,CSoftmax(qH+1⊤BΦY(yi))
总结
总体来看其实和端到端的很像,但是引入了key-value对之后可以事先对外部知识编码,这样就不用更多的依赖模型训练的embedding而是在每次查询配对知识信息,使模型能找到更准确的记忆得到与答案更接近的输出。
Bidirectional Attentive Memory Networks for Question Answering over Knowledge Bases NAACL2019
之前看的记忆网络全是为了这篇19年做KBQA的做铺垫,也算是对KBQA的理解性试验,代码作者也已经开源,这里主要还是关注文章中如何利用MemNN解决关系检测的问题。下图就是整体框架:
乍一看这图还是挺复杂的,模块很多箭头也飞来飞去的,下面我会拆分开一步一步记录整个模型的流程。
1、Input module
这里使用BiLSTM对输入问句的 word embedding做编码得到 H Q H^Q HQ
2、Memory module
首先得到候选实体(答案) { A i } i = 1 ∣ A ∣ \left\{A_{i}\right\}_{i=1}^{|A|} { Ai}i=1∣A∣ 并对其做三种信息的编码(实体候选文中用的也是别人的方法,这里就不介绍了),三种信息可以结合下图来看:
Key-value memory module使用了一个 key-value memory network来存储候选答案。将以上三种编码信息按 d ∗ 3 d*3 d∗3 的形式分别存储到(key,value)中。
3、Reasoning module
整个框架的核心部分就是这个推理模块。
KB-aware attention module
对 H Q H^Q HQ 做self-attention后 A Q Q = softmax ( ( H Q ) T H Q ) \mathbf{A}^{Q Q}=\operatorname{softmax}\left(\left(\mathbf{H}^{Q}\right)^{T} \mathbf{H}^{Q}\right) AQQ=softmax((HQ)THQ) 再用BiLSTM编码得到 question vector: q = BiLSTM ( [ H Q A Q Q T , H Q ] ) \mathbf{q}=\operatorname{BiLSTM}\left(\left[\mathbf{H}^{Q} \mathbf{A}^{Q Q^{T}}, \mathbf{H}^{Q}\right]\right) q=BiLSTM([HQAQQT,HQ]) 。后面就像一个 Multi-head Attention一样拼接三个信息流得到 KB summary : m = [ m t ; m p ; m c ] \mathbf{m}=\left[\mathbf{m}_{t} ; \mathbf{m}_{p} ; \mathbf{m}_{c}\right] m=[mt;mp;mc] ,将其与 H Q H^Q HQ 相乘得到 q q q 中每个单词 q i q_i qi 与KB信息的相关性,用maxpool、softmax得到 a ~ Q \tilde{\mathbf{a}}^{Q} a~Q,他代表的是问句中每个单词 q i q_i qi 对于 t y p e , p a t h , c o n t e x t type, path, context type,path,context 的权重分配。
Importance module
其中 A Q M A^{QM} AQM 建模了三种信息各自对于 q q q 的联系, A ~ M \tilde{\mathbf{A}}^{M} A~M 表示每种信息对于候选答案的重要程度。然后将权重赋予key值得到 *question-aware memory representations M ~ k \tilde{\mathbf{M}}^{k} M~k 。
Enhancing module
这个模块的式子写的很复杂,相当于在之前得到的互信息注意力机制权重的基础上对于原始的 q q q 和 KB信息做augmentation。对于 q q q 来说,标准化 A M Q = max k { A . . . , k Q M } k = 1 3 \mathbf{A}_{M}^{Q}=\max _{k}\left\{\mathbf{A}_{ . . ., k}^{Q M}\right\}_{k=1}^{3} AMQ=maxk{ A...,kQM}k=13得到 A ~ M Q \tilde{\mathbf{A}}_{M}^{Q} A~MQ 并把它结合到 question representation: H ~ Q = H Q + a ~ Q ⊙ ( A ~ M Q M ~ v ) T \tilde{\mathbf{H}}^{Q}=\mathbf{H}^{Q}+\tilde{\mathbf{a}}^{Q} \odot\left(\tilde{\mathbf{A}}_{M}^{Q} \tilde{\mathbf{M}}^{v}\right)^{T} H~Q=HQ+a~Q⊙(A~MQM~v)T,最终的 KB-enhanced question representation: q ~ = H ~ Q a ~ Q \tilde{\mathbf{q}}=\tilde{\mathbf{H}}^{Q} \tilde{\mathbf{a}}^{Q} q~=H~Qa~Q
同样的,对于KB来说,增强后的 question-enhanced KB representation M ‾ k \overline{\mathbf{M}}^{k} Mk:
M ‾ k = M ~ k + a ~ M ⊙ ( A ~ Q M ( H ~ Q ) T ) a ~ M = ( A ~ M Q ) T a ~ Q ∈ R ∣ A ∣ × 1 A ~ Q M = softmax ( A M Q T ) ∈ R ∣ A ∣ × ∣ Q ∣ \begin{aligned} \overline{\mathbf{M}}^{k} &=\tilde{\mathbf{M}}^{k}+\tilde{\mathbf{a}}^{M} \odot\left(\tilde{\mathbf{A}}_{Q}^{M}\left(\tilde{\mathbf{H}}^{Q}\right)^{T}\right) \\ \tilde{\mathbf{a}}^{M} &=\left(\tilde{\mathbf{A}}_{M}^{Q}\right)^{T} \tilde{\mathbf{a}}^{Q} \in \mathbb{R}^{|A| \times 1} \\ \tilde{\mathbf{A}}_{Q}^{M} &=\operatorname{softmax}\left(\mathbf{A}_{M}^{Q^{T}}\right) \in \mathbb{R}^{|A| \times|Q|} \end{aligned} Mka~MA~QM=M~k+a~M⊙(A~QM(H~Q)T)=(A~MQ)Ta~Q∈R∣A∣×1=softmax(AMQT)∈R∣A∣×∣Q∣
Generalization modul
最后的答案生成模块将上述的两个输出做attention、GRU,并用残差和batch mormalization得到最终的输出 q ^ \hat{\mathbf{q}} q^:
a = Att a d d G R U ( q ~ , M ‾ k ) \mathbf{a}=\operatorname{Att}_{\mathrm{add}}^{\mathrm{GRU}}\left(\tilde{\mathbf{q}},\overline{\mathbf{M}}^{k}\right) a=AttaddGRU(q~,Mk) m ~ = ∑ i = 1 ∣ A ∣ a i ⋅ M ~ i v \tilde{\mathbf{m}}=\sum_{i=1}^{|A|} a_{i} \cdot \tilde{\mathbf{M}}_{i}^{v} m~=i=1∑∣A∣ai⋅M~iv q ′ = GRU ( q ~ , m ~ ) \mathbf{q}^{\prime}=\operatorname{GRU}(\tilde{\mathbf{q}}, \tilde{\mathbf{m}}) q′=GRU(q~,m~) q ^ = B N ( q ~ + q ′ ) \hat{\mathbf{q}}=\mathrm{BN}\left(\tilde{\mathbf{q}}+\mathbf{q}^{\prime}\right) q^=BN(q~+q′)
4、Answer module
简单的目标函数: S ( q , a ) = q T ⋅ a S(\mathbf{q}, \mathbf{a})=\mathbf{q}^{T} \cdot \mathbf{a} S(q,a)=qT⋅a 计算 q ^ \hat{\mathrm{q}} q^ 和每个候选 answer 的匹配得分并排序。
损失函数还是基于 hinge loss:
ℓ ( y , y ^ ) = max ( 0 , 1 + y ^ − y ) \ell(y, \hat{y})=\max (0,1+\hat{y}-y) ℓ(y,y^)=max(0,1+y^−y) g ( q , M ) = ∑ a + ∈ A + ℓ ( S ( q , M a + ) , S ( q , M a − ) ) g(\mathbf{q},\mathbf{M})=\sum_{a^{+} \in A^{+}} \ell\left(S\left(\mathbf{q}, \mathbf{M}_{a^{+}}\right), S\left(\mathbf{q}, \mathbf{M}_{a^{-}}\right)\right) g(q,M)=a+∈A+∑ℓ(S(q,Ma+),S(q,Ma−)) o = g ( H Q a ~ Q , ∑ j = 1 3 M ⋅ , j k ) + g ( q ~ , M ‾ k ) + g ( q ^ , M ‾ k ) + g ( q w , H t 2 ) \begin{aligned} o=g\left(\mathbf{H}^{Q} \tilde{\mathbf{a}}^{Q}\right.&, \sum_{j=1}^{3} \mathbf{M}_{\cdot, j}^{k} )+g\left(\tilde{\mathbf{q}}, \overline{\mathbf{M}}^{k}\right) \\ &+g\left(\hat{\mathbf{q}}, \overline{\mathbf{M}}^{k}\right)+g\left(\mathbf{q}^{w}, \mathbf{H}^{t_{2}}\right) \end{aligned} o=g(HQa~Q,j=1∑3M⋅,jk)+g(q~,Mk)+g(q^,Mk)+g(qw,Ht2)
总结
这篇文章并没有拿MemNN做多跳的推理,因为训练测试集里对于每个entity数据都会挖掘2-hop以内answer的,其实还是有点失望的,这种是针对数据集本身的特性做出的而不能泛化到更复杂的数据集上,但是它里面的互注意力机制在对 q 和 KB 的建模起到十分关键的作用。
记忆网络能很好地针对QA任务中的Multi-hop特性,如阅读理解中的上下文推理、多篇章的答案抽取,如KBQA中的多关系多实体的问题,如对话系统中的状态跟踪、多轮对话,最近facebook还用memory嵌入BERT体系大大提高BERT的效率,更多运用个人认为还是可以继续跟进的。
另外因为是由KBQA引入的记忆网络,所以这里很多关于记忆网络其他的论文没有提及,还有 Gate MemNN、Dynamic Memory Networks 等论文可以做进一步研究,也有一些写的质量很好的博客可以参考:
知乎专栏-记忆网络