架构级理解BERT——逃不掉的RNN

前言

写这一个系列的动因就是自己想深入了解一下BERT的原理。BERT是怎么被构想出来的?比较适合刚入门的小白阅读,读完之后会发现其实不过如此。那么既然是架构级的,本系列不会过多的涉及代码级的或者说公式级的,更多的是设计方式以及设计的原因。本系列将分成以下几个模块,

  1. 逃不掉的RNN
  2. 探求机翻的内幕:Seq2Seq [https://segmentfault.com/a/11...]
  3. Attention is ALL you need
  4. Transformer是谁?
  5. This is BERT

时序数据

回想我们知道的全连接层,他实际上做的就是将一input转换为一堆output,这些input之间没有时间上的关系,而是将所有input揉在一起输入到模型中。而对于时序数据,比如一段话,人在阅读的时候是单向的、随着时间将input逐个输入,并且很多时候input的长度是未知的。所以我们就需要一个新的架构来处理这一类的时序数据。

Simple RNN

为什么需要RNN?假设我们需要解决一个文本分类任务。

要想对文本进行分类,那么我们必须将这个文本数字化或者说向量化,并且需要保证这个向量能够表征这句话,包含了这句话中所有的特征。而通过RNN就能够将一个句子中的所有信息都融合起来,表征成一个向量。

上面就是RNN的架构图,我们一步一步来讲。首先是我们的输入,显然是一串文本:the cat sat ... mat。那么对于某个单词而言,比如the,想要参与后续的运算,首先需要将其向量化,这里可以用到word2vec等算法,通过语义的方式将单词转化为一个向量x0。然后将x0输入矩阵A,内部过程如下,

架构级理解BERT——逃不掉的RNN_第1张图片

将x0与ht-1(此处因为x0是第一个,ht-1理论上来说不存在,那么可以通过一些处理比如将其置零之类代替)连接起来,和矩阵A相乘,再进行tanh激活函数计算,得到h0。然后将h0和x1输入到矩阵A又重复上述计算,循环往复(所以RNN叫循环神经网络,上一步的输出又作为下一步的输入)。显然,h0中存储着x0的状态,h1中存储着x0和x1的状态,以此类推,最后一个输出ht中应该存储着前面所有xi(单词)的信息。这样,ht相当于征集了所有村民的意见,就可以作为代表拿去评估,完成二分类的任务。

此处思考一个问题:为什么需要tanh激活?

假设没有tanh,为方便讨论,我们假设所有的x都是零向量,那么ht ≈ A×ht-1。易得h100 = A100h0。易得Ax = λx,A100x = λ100x,若A的特征值λ稍大于1,那么A100就直接爆炸了;若λ稍小于1,那么A100估计就直接变成零矩阵了。所以需要tanh进行一个类似于正则化的工作。

缺点使人进步,RNN存在什么缺点呢?=>RNN的记忆力特别短,什么是记忆力?之前我们说“最后一个输出ht中应该存储着前面所有xi(单词)的信息”,确实,但是当序列长度变大,对于最前面的x所遗留下来的特征可能已经被覆盖掉了。可以通过计算h~100~关于x0的导数来判断之间的相关性,计算可得,导数接近于0。

架构级理解BERT——逃不掉的RNN_第2张图片

所以我们可以发现,simpleRNN在处理短距离的文本时效果较好,当序列长度变大之后,效果就不太好了。所以引出LSTM。

改造SimpleRNN => LSTM

可见LSTM其实也是一类RNN罢了,此处我不会详细解释其原理,因为Attention is ALL you need,RNN可以被attention取代了。

首先来看LSTM的架构,

架构级理解BERT——逃不掉的RNN_第3张图片

仔细观察发现,其实和简单RNN没有很大区别,其实就是把之前单纯的乘矩阵A转化为一系列更为复杂的操作。所以SimpleRNN可以被LSTM完全替换掉,就像替换某个零件一样。内部的大致过程如下。

为了弥补SimpleRNN 记忆力丢失的问题,LSTM将记忆放在一个传输带上,也就是下图中的Conveyor Belt,也就是左下方的C~t-1~,记忆的更新方式如图右下方所示,其中可以发现f向量,这是LSTM中的遗忘门,他可以控制哪些信息忽略,哪些信息保留。遗忘门也保证了长距离记忆始终存在。

架构级理解BERT——逃不掉的RNN_第4张图片

最后是LSTM的输出,在记忆信息C的基础上进行一定的加工之后得到ht。同样和SimpleRNN类似,将上一步的输出作为输入传入下一层。

架构级理解BERT——逃不掉的RNN_第5张图片

你可能感兴趣的:(nlp自然语言处理神经网络)