Breaking the Softmax Bottleneck:A High-rank RNN Language Model

Content

此篇论文主要完成了:
1.通过数学推导,找到并证明了限制RNN-based LMs的性能瓶颈之一——Softmax Bottleneck问题
2.针对这个瓶颈,提出了一个解决方案—— Mixture of Softmaxes

Introduction

Language Modeling Problem

LM问题在指已知了一个符号(token)序列:

的情况下,生成一个Language模型来模拟这个序列出现的概率,即求解P(X)的值:

而根据链式法则(chain rule)和马尔科夫假设(Markov Assumption), P(X)的值可以通过求解它对应的联合概率得出:

: 下一个符号的概率分布
:历史序列 / 已经出现的所有Tokens

因此,原始的LM问题就转换成了:在每个时刻t,根据已知的符号序列(History), 求解下一时刻可能输出的符号的概率。 由于输出符号有很多种而可能性,所以这个概率的其实是一个在Vocabulary(或Token Set)上的概率分布。

Standard Approach: RNN based LMs

由于符号序列自带时间属性,而我们需要模拟的也是符号之间的时间依赖关系。因此对于LM问题来说,最标准的,而且state of art 模型都是基于RNN的。以下为一个基于RNN的Language Model的结构简图:

RNN based LMs

其中:



首先图片左下角的符号序列 " the cat sat on the " 是我们在此时刻 t 已知的历史序列,即
。由于输入的每个Token是由one-hot编码表示的,当Vocabulary很大的情况下,这个输入维度会非常的高。因此在处理这种高维输入时,会先使用word embedding matrix ( 图中矩阵U ) 来降低维度&学习词语的内部联系,使输入更有意义。
之后,经过处理的输入会被传送给RNN。基于此时刻 t 的输入和上一时刻RNN的旧隐藏状态
, RNN会产生新的隐藏状态

此隐藏状态可能看作是一个由RNN学习到的,含有下一时刻输出符号的信息的特征。由于输出是一个基于Vocabulary的概率分布,因此我们必须把学习到的这个特征映射回初始的Vocabulary。上图中,
下的output embedding matrix (W) 就负责这个反映射。应用上,两个embedding 矩阵 U和W是一样的。

Hypothesis&Main Issues

由Anton Maximilian Schäfer and Hans Georg Zimmermann写的Recurrent Neural Networks Are Universal Approximators论文可以得知,RNN的表达力是很强的,它可以模拟逼近任意的非线性动态系统(Universal approximation theorem)。由此作者推测出,基于RNN的LMs的性能瓶颈之一应该是RNN最后使用点乘+softmax操作,即:。

Mathematical Analysis of LM

Defination

为了能进行数学推导和定量分析来证明这个假设,首先我们需要一个自然语言的数学表达。自然语言L可以表示成N个元组的集合:

其中:
代表了语言中的任一个可能的context(history token序列)
真实的数据分布,即:已知一个历史符号序列(),下一符号在Token集合上的概率分布
代表了语言L中所有可能出现的符号
所有可能的上下文(符号组合)的数目
至此,LM问题可以转换成如下的数学公式表达:

即,给定一个自然语言L,LM需要学习一组参数,基于此组参数的模型可以逼近真实的任一上下文(context)所对应的下一符号概率分布。
若我们使用RNN-based LMs, 那么在network的输出端,我们能从softmax layer 的输出直接得到基于此时刻 t 的下一符号概率分布 :

因此,训练模型的Objective可以用以下等式表达:

即,我们使用一个RNN-based LM 来模拟每个可能context下的下一符号概率分布,并且不断优化模型使用的参数,使LM输出的概率分布逼近真实分布。

Matrix Factorization Problem

在数学化表达LM问题后,它的Objective公式还可以通过矩阵分解来做进一步的分析。
在的表达式中,代表了输入是不同的context(历史序列)的情况下,RNN所对应的不同隐藏状态。此处,可以把所有可能的情况列出,排列组合成一个矩阵:

这个矩阵包含了RNN针对不同的Context的所有可能的隐藏状态。根据此节开篇的假设,自然语言L一共有种可能的context(即:符号组合序列)。
相似地,公式中的也可以统一成矩阵表达:

中的每一行代表了语言L中的某一个符号所对应的embedding coefficient,用以把RNN学到的隐藏状态映射回包含符号集的Vocabulary空间。同样,根据此节开篇的假设,自然语言L一共有种可能的符号(tokens)。
最后,我们还需要把自然语言L真实的条件概率分布(在各种可能的context下,下一符号的概率分布)用矩阵的方式表达,从而能使用矩阵知识,数学地分析RNN-based LMs。此处假设矩阵代表了真实条件概率分布取后的结果:
A= \left[ \begin{matrix} \log{P^*(x_1|c_1)} &\log{P^*(x_2|c_1)}&...&\log{P^*(x_M|c_1)}\\ \log{P^*(x_1|c_2)} &\log{P^*(x_2|c_2)}&...&\log{P^*(x_M|c_2)} \\ ...&...&...&... \\ \log{P^*(x_1|c_N)} &\log{P^*(x_2|c_N)}&...&\log{P^*(x_M|c_N)} \end{matrix} \right]
由上公式可知,包含了context与对应next token的所有可能的组合。

Rank Analysis

在经历了上述对Objective的分析及矩阵转换,RNN-based LM问题事实上可以抽象如下:

即,通过学习,我们希望找到一组参数,以它为参数的LM模型(即RNN)可以逼近真实的下一符号概率分布的。
为了能推导出Softmax存在的瓶颈,首先先要引入一个矩阵操作。对一个矩阵 A 进行操作,其结果为一个矩阵集合

其中:
维度对应的全1矩阵
对角线元素值任意的对角线矩阵
事实上,的作用是把矩阵中的每行元素上加上任意一个实数,例如如下与矩阵相加后,的第 i 行会被加上一个实数。
\left[ \begin{matrix} a_1&0&0 \\ 0&a_2&0 \\ 0&0&a_3 \end{matrix} \right]_{\Lambda^{3\times3} } \times{\left[ \begin{matrix} 1&1&1 \\ 1&1&1 \\ 1&1&1 \end{matrix} \right]_{J^{3\times4}}}={\left[ \begin{matrix} a_1&a_1&a_1 \\ a_2&a_2&a_2 \\ a_3&a_3&a_3 \end{matrix} \right]}
而代表真实下一符号概率分布的 矩阵 与 经由 所得到的矩阵集合 ,有如下两个特殊性质:
1.所有真实数据分布所对应的logits都包含在了集合中。

  1. 中的所有矩阵的秩
    都相似,相差不大于1。
    附--矩阵的秩 :
    -定义: 矩阵中所有线性独立的列的数目和
    -直观解释:如果一个矩阵有着更高的秩,那么说明它有更多的线性独立的列。若把这些列看作是一组 basis vectors ,那么它们所能表达的空间就更复杂,表达能力就更强。即,高秩的矩阵能包含更多的信息量。
    -例子:如果我们把某自然语言L表示成矩阵形式(如上节中的矩阵),那么此矩阵天然拥有高秩的性质,例如:
    -它是高度依赖上下文的——“南”后面的符号可以是“京”或者“瓜”,取决于前后文是关于地理的还是农业的。即,在不同的上下文里,下一符号的概率分布会非常不同。
    -并且我们不可能找到一组有限数目的basis vectors,使用此基来表达语言L中的所有Token的关系。

review

由RNN-based LM的结构推导出,它的Objective如下:

通过把自然语言表达成矩阵形式,再进行矩阵分解(Matrix Factorization ),LM的目标可以抽象成如下表达。即,LM需要找到一组参数,借由这组参数生成的下一符号概率能无限逼近真实概率:

而通过引入矩阵运算符 row-wise shift ,以及此运算产生的矩阵集F(A)的第一个性质,我们可以推出,若RNN-based LM真的能逼近真实概率分布,那么它产生的 logits 必定属于真实概率分布矩阵 Arow-wise shift 结果集合中。即,Objective为如下:

Problem: Softmax Bottleneck

至此,LM问题的核心变成了研究是否真的存在一组参数使基于此的LM所产生的logits属于 ,如下:

回忆一下,如上公式中:
代表了所有可能的context输入下的对应隐藏状态。
代表了语言中所有可能的token所对应embedding coefficient
因此,由线性代数的知识可知,它们乘积的秩应该小于d,即:

(相较于自然语言中的context数目N和token数目M,embedding size d显然会小很多)
又由于row-wise shift的第二个性质(即:中的所有矩阵的秩都相似,相差不大与1)可推导出,若embedding size d有:

则对应的RNN-based LM 产生的logits不可能属于。换句话说,此LM不可能找到一组参数,使其能recover真实概率分布A
到底embedding size d能否满足上述不等式呢?我们已知,真实概率分布矩阵A也属于F(A),而且它是高秩的矩阵,其秩最大能和它的context数目相当()。而embedding本就是为了精简输入维度而使用的,所以它的维度一般会较小()。所以显然成立:

即,RNN-based LM 不可能找到一组参数 ,使其能recover真实概率分布 A。它只是一个真实概率分布的低秩近似,表达能力不够,因此失去了一些模拟context间依赖性的能力。这也正是性能瓶颈所在。

Sloution for Softmax Bottleneck

Naive Solution

要解决这个瓶颈问题,一个最直观的方法就是提高embedding size d。但是这显然与embedding的目的不符。另一个方法是使用Ngram模型,来避免Softmax的使用。这两种方法都会使总参数数目急剧增加,容易导致过拟合,显然都不可取。

Mixture of Softmaxes

而另一种方法就是使用作者提出的 MoS(Mixture of Softmaxes) 来替代原始的 Softmax 。MoS的公式如下:

由名字可知,Mos便是把多个Softmax按权相加,综合为一个Softmax混合模型。
传统的RNN-based LM的结构如下左图,而基于MoSRMM-LM 位于下图右。由比较可看出,仅在RNNhidden state 以后有所不同。

standard RNN vs. MoS

这两种不同的模型最后产生的下一符号概率分布的
也不同,如下:


这个优化版本由于引入了按权相加,因此在最后计算完
运算后,与模型产生的logits不再是原本的线性关系,理论上可以达到任意的高秩,因此提升了模型的表达能力。

Experiments

使用MoS的RNN与其他模型在LM问题上的表现对比如下:


result

Drawback

当然,MoS模型也有它的缺憾。由于使用了多个并行的Softmax按权相加,因此它的运算时间是原有模型的数倍。在实践中,其实Softmax Layer的计算是尤其费时的,因此这也算是不小的短板。由下图实验数据可知,MoS模型的计算时间与它所用的Softmax的数目K近似呈线性关系。

drawback

computational time / #softmax

Summary

现在普遍使用的RNN-based LM,由于在最后把RNN输出的隐藏状态乘以了output embedding matrix,并把得到的结果(logits)输入了softmax layer,导致最后整体模型所能模拟的概率分布空间的秩被embedding-size d 所限制。而MoS模型通过引入按权相加的运算打破了原来的线性关系,提高了模型模拟空间的秩。当然,其代价是线性增加的运算时间。

REFERENCES

[1]Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, William W. Cohen. Breaking the Softmax Bottleneck: A High-Rank RNN Language Model. In ICLR 2018.
[2]Anton Maximilian Schäfer and Hans Georg Zimmermann. Recurrent neural networks are universal approximators. In International Conference on Artificial Neural Networks, pp. 632–640. Springer, 2006.
[3]Tomas Mikolov, Martin Karafiát, Lukas Burget, Jan Cernocky, and Sanjeev Khudanpur. Recurrent neural network based language model. In Interspeech, volume 2, pp. 3, 2010.
[4]Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and optimizing lstm language models. arXiv preprint arXiv:1708.02182, 2017.

你可能感兴趣的:(Breaking the Softmax Bottleneck:A High-rank RNN Language Model)