[转载]记忆网络之End-To-End Memory Networks

记忆网络之End-To-End Memory Networks

这是Facebook AI在Memory networks之后提出的一个更加完善的模型,前文中我们已经说到,其I和G模块并未进行复杂操作,只是将原始文本进行向量化并保存,没有对输入文本进行适当的修改就直接保存为memory。而O和R模块承担了主要的任务,但是从最终的目标函数可以看出,在O和R部分都需要监督,也就是我们需要知道O选择的相关记忆是否正确,R生成的答案是否正确。这就限制了模型的推广,太多的地方需要监督,不太容易使用反向传播进行训练。因此,本文提出了一种end-to-end的模型,可以视为一种continuous form的Memory Network,而且需要更少的监督数据便可以进行训练。论文中提出了单层和多层两种架构,多层其实就是将单层网络进行stack。我们先来看一下单层模型的架构:

单层 Memory Networks

单层网络的结构如下图所示,主要包括下面几个模块:

模型主要的参数包括A,B,C,W四个矩阵,其中A,B,C三个矩阵就是embedding矩阵,主要是将输入文本和Question编码成词向量,W是最终的输出矩阵。从上图可以看出,对于输入的句子s分别会使用A和C进行编码得到Input和Output的记忆模块,Input用来跟Question编码得到的向量相乘得到每句话跟q的相关性,Output则与该相关性进行加权求和得到输出向量。然后再加上q并传入最终的输出层。接下来详细介绍一下各个模块的原理和实现(这里跟论文中的叙述方式不同,按照自己的理解进行介绍)。

输入模块

首先是输入模块(对应于Memory Networks那篇论文的I和G两个组件),这部分的主要作用是将输入的文本转化成向量并保存在memory中,本文中的方法是将每句话压缩成一个向量对应到memory中的一个slot(上图中的蓝色或者黄色竖条)。其实就是根据一句话中各单词的词向量得到句向量。论文中提出了两种编码方式,BoW和位置编码。BoW就是直接将一个句子中所有单词的词向量求和表示成一个向量的形式,这种方法的缺点就是将丢失一句话中的词序关系,进而丢失语义信息;而位置编码的方法,不同位置的单词的权重是不一样的,然后对各个单词的词向量按照不同位置权重进行加权求和得到句子表示。位置编码公式如下:lj就是位置信息向量(这部分可以参考我们后面的代码理解)。

此外,为了编码时序信息,比如Sam is in the bedroom after he is in the kitchen。我们需要在上面得到mi的基础上再加上个矩阵对应每句话出现的顺序,不过这里是按反序进行索引。将该时序信息编码在Ta和Tc两个矩阵里面,所以最终每句话对应的记忆mi的表达式如下所示:

输出模块

上面的输入模块可以将输入文本编码为向量的形式并保存在memory中,这里分为Input和Output两个模块,一个用于跟Question相互作用得到各个memory slot与问题的相关程度,另一个则使用该信息产生输出。

首先看第一部分,将Question经过输入模块编码成一个向量u,与mi维度相同,然后将其与每个mi点积得到两个向量的相似度,在通过一个softmax函数进行归一化:

pi就是q与mi的相关性指标。然后对Output中各个记忆ci按照pi进行加权求和即可得到模型的输出向量o。

Response模块

输出模块根据Question产生了各个memory slot的加权求和,也就是记忆中有关Question的相关知识,Response模块主要是根据这些信息产生最终的答案。其结合o和q两个向量的和与W相乘在经过一个softmax函数产生各个单词是答案的概率,值最高的单词就是答案。并且使用交叉熵损失函数最为目标函数进行训练。

多层模型

其实就是将多个单层模型进行stack在一块。这里成为hop。其结构图如下所示:

首先来讲,上面几层的输入就是下层o和u的和。至于各层的参数选择,论文中提出了两种方法(主要是为了减少参数量,如果每层参数都不同的话会导致参数很多难以训练)。

1. Adjacent:这种方法让相邻层之间的A=C。也就是说Ak+1=Ck,此外W等于顶层的C,B等于底层的A,这样就减少了一半的参数量。

2. Layer-wise(RNN-like):与RNN相似,采用完全共享参数的方法,即各层之间参数均相等。Ak=...=A2=A1,Ck=...=C2=C1。由于这样会大大的减少参数量导致模型效果变差,所以提出一种改进方法,即令uk+1=Huk+ok,也就是在每一层之间加一个线性映射矩阵H。

转载自知乎专栏:记忆网络-Memory Network

你可能感兴趣的:([转载]记忆网络之End-To-End Memory Networks)