xlnet

Autoregressive和Autoencoding

假设输入序列为$(w_1,w_2,w_3,w_4,w_5)$。
autoregressive and autoencoding.png

Autoregressive

autoregressive模型使用前t-1个单词预测第t个单词,目标函数如下,
$$\max_{\theta} \log p_{\theta}(\mathbf{x})=\sum_{t=1}^{T}\log p_{\theta}(x_t|\mathbf{x}_{优点:

  1. =而非$\approx$,可以严格推到
  2. no discrependency
  3. 考虑被预测单词之间的依赖关系

缺点:

  1. 不能同时考虑双向信息(要么从左到右,要么从右到左)

Autoencoding

autoencoding模型,使用未被mask掉的单词预测被mask掉的单词,目标函数如下,
$$\max_{\theta} \log p_{\theta}(\overline{\mathbf{x}})|\hat{\mathbf{x}})\approx\sum_{t=1}^T m_t \log p_{\theta}(x_t | \hat{\mathbf{x}}) $$
其中,$\overline{\mathbf{x}}$表示句子中被mask掉的单词, $\hat{\mathbf{x}}$表示句子中没有被mask掉的单词。 $m_t$表示如果输入句子的第t个单词被mask掉,则为1,否则为0。

优点:

  1. 能同时考虑双向信息

缺点:

  1. Independent Assumption没有考虑被mask掉的单词的相关性(即单独地考虑被mask掉的单词);
  2. Training和Testing之间存在Discrependency(在训练时,输入有mask token,但是在预测时没有)。

两种模型取其精华,去其糟粕。

改进

使Autoregressive model可以同时考虑双向信息

具体做法是使用Permutation Language Model

对于长度为$T$的序列,有$T!$种排列。为了减少计算复杂度,对所有的permutation进行采样,最后计算期望。
$$\max_{\theta}\mathbb{E}_{z\sim \mathcal{Z}_T}[\sum_{t=1}^{T}\log p_{\theta}(x_t | \mathbf{x}_{其中,$\mathcal{Z}_T$表示所有可能的permutation。
训练时,还是使用一部分词去预测另一部分词,但是会考虑相关性。
在具体实现时,输入单词的顺序是不改变的,通过引入attention mask,起到permutation的效果。
permutation.png
图2
例如,对于一个输入$(x_1,x_2,x_3,x_4)$,对单词$x_3$进行预测,当permutation为$(x_3,x_2,x_4,x_1)$时,通过attention mask使得$h_3$只会考虑$mem$,即求$p_{\theta}(x_3|mem)$;当permutation为$(x_2,x_4,x_3,x_1)$时,$h_3$会考虑$mem$、$x_2$和$x_4$的信息,即求概率$p_{\theta}(mem,x_2,x_4)$。

target aware

假设使用标准的softmax来表示next-token的分布,即给定前t-1个单词,第t的单词的概率 $p_\theta(x_{z_t}|\mathbf{x}_{z_{$$p_\theta(x_{z_t}=x|\mathbf{x}_{z_{其中,$h_{\theta}(\mathbf{x}_{z_{存在两个不同的permutation $z^{(1)}$和$z^{(2)}$,满足$z^{(1)}_{为此,xlnet提出了target position aware next-token distribution。
$$p_\theta(x_{z_t}=x|\mathbf{x}_{z_{其中,$g_{\theta}(\mathbf{x}_{z_{

位置信息

接下来需要考虑如何定义$g_{\theta}(\mathbf{x}_{z_{

  1. 在预测单词$x_{z_t}$时,$g_{\theta}(\mathbf{x}_{z_{
  2. 当预测单词$x_{z_j},(j>t)$,$g_{\theta}(\mathbf{x}_{z_{

这就产生了矛盾,举例来说。例如,输入序列为$(x_1,x_2,x_3,x_4)$,需要对$x_2$和$x_4$进行预测,此时permutation为(3,2,4,1)。当预测$x_2$,希望模型只能知道$x_2$的位置信息但是不能知道$x_2$的内容,但是预测$x_4$又希望知道$x_2$的内容,即既不希望$g_2$包含$x_2$的内容,又希望$g_2$包含$x_2$的内容。这就产生了矛盾。
冲突.png
图3

Two-Stream Self Attention

xlnet的解决方法是,使用两种hidden representation。

  1. content representation $h_{\theta}(\mathbf{x}_{\leq t})$
  2. query representation $g_{\theta}(\mathbf{x}_{z_{

其中,content representation $h_{\theta}(\mathbf{x}_{\leq t})$与标准的transfomer中的hidden state一样,都编码了单词$x_{z_t}$自身的内容信息和其上下文的内容信息。query representation $g_{\theta}(\mathbf{x}_{z_{可以将这两种representation类比word2vec模型中的context embeding和target embedding。content representation $h_{\theta}(\mathbf{x}_{\leq t})$可以看作是单词$x_{z_t}$作为其它单词的上下文(context)时,使用的hidden state;query representation $g_{\theta}(\mathbf{x}_{z_{如图4(c)所示,在输入层,$h^{(0)}$为单词的embedding,$g^{(0)}$为随机初始化的向量。
query representation更新.png
图6
query representaion和hidden representation在前向传播中的计算方式为
$$g^{(m)}_{z_t}=Attention(Q=g^{(m-1)}_{z_t},KV=h^{(m-1)}_{z_{$$h^{(m)}_{z_t}=Attention(Q=h^{(m-1)}_{z_t},KV=h^{(m-1)}_{z_{\leq t}};\theta)$$
如图4(c)所示在计算query representaion$g^{(m)}_{z_t}$时不会使用到单词自身的内容信息,对应的attention mask,其主对角线上的值为0;在计算hidden representaion$h^{(m)}_{z_t}$时,会使用到单词自身的内容信息,对应的attention mask,其主对角线上的值不为0。

xlnet选择的Autoregressive Model时Transfomer-XL,引入Transfomer-XL的两大特性:relative positional encoding schemasegment recurrence mechanism
对于xlnet的输入,假定又两个segment,$\tilde{\mathbf{x}}$和$\mathbf{x}$,分别对两个segment进行permutation,并且前一个segment的hidden state会被后一个segment使用。
two stream attention with transformer-xl.png

训练采用partial prediction

xlnet通过使用permutation language modeling,使得模型可以同时考虑双向的信息(上下文),但是也带来了一些问题。(1)permuation数量大,可通过采样解决;(2)在一个permutation中位置靠前的单词,其上下文的长度是很短的,对这些单词进行预测,收益不大,因此xlnet在训练模型时,只对permutation中靠后的单词进行预测,即采用partial prediction的做法。
这样做的好处是,(1)靠后的单词,上下文长度较长,因此对样本的利用较为充分;(2)只对虑靠后的少量单词进行预测(假设为后$c$个的单词),因此没必要计算前$c$个单词的query representation(前$c$个单词全部充当的是上下文的角色,因此只需计算hidden representation),极大地减少了计算量。
目标函数为,
$$\max_{\theta}\mathbb{E}_{z\sim \mathcal{Z}_T}[\log p_{\theta}(\mathbf{x}_{z_{>c}}|\mathbf{x}_{z_{\leq c}})]=\mathbb{E}_{z\sim \mathcal{Z}_T}[\sum_{t=c+1}^{|z|}\log p_{\theta}(x_{z_t})|\mathbf{x}_{z_{

你可能感兴趣的:(神经网络,自然语言处理)