Star-Transformer
来自NAACL 2019的论文。
问题:
解决:
Star-Transformer用星型拓扑结构代替了全连通结构如上图左边是Transformer,而右边是Star-Transformer。在右边的图中,所有序列中直接相邻的词可以直接相互作用,而非直接相邻的元素则通过中心节点实现间接得信息传递,因此,复杂性从二次降低到线性,同时保留捕获局部成分和长期依赖关系的能力。
paper:https://arxiv.org/abs/1902.09113
code:https://github.com/fastnlp/fastNLP
Transformer-XL
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context,ACL 2019
问题:
那么如何学习更长语义联系?
segment-level Recurrence
segment-level循环机制。如上图左边为原始 Transformer,右边为 Transformer-XL,Transformer-XL 模型的计算当中加入绿色连线,使得当层的输入取决于本序列和上一个序列前一层的输出。这样每个序列计算后的隐状态会参与到下一个序列的计算当中,使得模型能够学习到跨序列的语义联系。(看动图可能更好理解)
h r n h^n_r hrn是第 r r r个segment的第n层隐向量,那么第r+1个的第n层的隐向量的计算,就是上面这套公式。
还有一点设计是,在评估预测模型的时候它是会连续计算前L个长度的隐向量的(训练的时候只有前一个,缓存在内存中)。即每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的token存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),这样能使跨语义更加的深入。
只看看XL多头注意力的forward的不同地方吧
def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
#w是上一层的输出,r是相对位置嵌入(在下一节),r_w_bias是u,r_r_bias是v向量
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None: #mems就是前一些序列的向量,不为空
cat = torch.cat([mems, w], 0) #就拼起来
if self.pre_lnorm: #如果有正则化
w_heads = self.qkv_net(self.layer_norm(cat)) #这个net是nn.Linear,即qkv的变换矩阵W参数
else:
w_heads = self.qkv_net(cat)#没有正则就直接投影一下
r_head_k = self.r_net(r)#也是nn.Linear
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) #复制3份
w_head_q = w_head_q[-qlen:] #q的W不要拼接的mems
else:#没有mems,就正常的计算
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
#qlen是序列长度,bsz是batch size,n_head是注意力头数,d_head是每个头的隐层维度
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
####计算注意力的四个部分
#AC是指相对位置的公式里的a和c两个部分,相对位置在下一节做笔记
rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head
#爱因斯坦简记法求和sum,统一的方式表示各种各样的张量运算
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
#BD是指相对位置的公式里的b和d两个部分
rr_head_q = w_head_q + r_r_bias
BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
BD = self._rel_shift(BD)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD #最后的结果
attn_score.mul_(self.scale)#进行放缩
Relative Position Encodings
相对位置编码。原始 Transformer 采用了正弦/余弦函数来编码绝对位置信息。然而因为Transformer-XL 会有多个句子,所以还是绝对位置,那么两个句子的相同位置是同样的编码,比如 [0, 1, 2, 3] 在两个句子concat之后就变成了[0, 1, 2, 3, 0, 1, 2, 3],句子不连续,而且每次拼的句子会不一样,也不能找到适合的绝对位置编码。所以这里使用相对位置编码。
上图是原始Transformer和Transformer-XL的比较,其中 E 表示词的 Embedding,而 U 表示绝对位置编码。这大一堆看起来奇奇怪怪,实际上Transformer的注意力计算是 ( E x i + U i ) W q W k ( E x j + U j ) (E_{x_i}+U_i)W_qW_k(E_{x_j}+U_j) (Exi+Ui)WqWk(Exj+Uj)的分解,即先编码Q(当前词 i)和K(其他的词 j)然后算内积,位置编码是直接add在词嵌入上面的。
而Transformer-XL的改变是:
细细思考,这attention的四个部分各有玄机:
相对位置编码的代码为:
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb #编码维度
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) #间隔频率
def forward(self, pos_seq):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq) #序列的位置向量 operation 间隔
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) #正弦余弦
return pos_emb[:,None,:] #直接返回R,非学习矩阵R
简单把编码维度设置为10,查询向量也是10个,存储之前的序列也是10,有以下结果:
>>> import torch
>>> inv_freq = 1 / (10000 ** (torch.arange(0.0, 10, 2.0) / 10))
>>> inv_freq
tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])
>>> pos_seq=torch.arange(20-1, -1, -1.0) #qlen+mlen,即10+10的维度然后逆序
>>> pos_seq
tensor([19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6.,
5., 4., 3., 2., 1., 0.])
>>> sinusoid_inp = torch.ger(pos_seq,inv_freq)
>>> sinusoid_inp
tensor([[1.9000e+01, 3.0113e+00, 4.7726e-01, 7.5640e-02, 1.1988e-02],
[1.8000e+01, 2.8528e+00, 4.5214e-01, 7.1659e-02, 1.1357e-02],
[1.7000e+01, 2.6943e+00, 4.2702e-01, 6.7678e-02, 1.0726e-02],
[1.6000e+01, 2.5358e+00, 4.0190e-01, 6.3697e-02, 1.0095e-02],
[1.5000e+01, 2.3773e+00, 3.7678e-01, 5.9716e-02, 9.4644e-03],
[1.4000e+01, 2.2189e+00, 3.5166e-01, 5.5735e-02, 8.8334e-03],
[1.3000e+01, 2.0604e+00, 3.2655e-01, 5.1754e-02, 8.2024e-03],
[1.2000e+01, 1.9019e+00, 3.0143e-01, 4.7773e-02, 7.5715e-03],
[1.1000e+01, 1.7434e+00, 2.7631e-01, 4.3792e-02, 6.9405e-03],
[1.0000e+01, 1.5849e+00, 2.5119e-01, 3.9811e-02, 6.3096e-03],
[9.0000e+00, 1.4264e+00, 2.2607e-01, 3.5830e-02, 5.6786e-03],
[8.0000e+00, 1.2679e+00, 2.0095e-01, 3.1849e-02, 5.0477e-03],
[7.0000e+00, 1.1094e+00, 1.7583e-01, 2.7867e-02, 4.4167e-03],
[6.0000e+00, 9.5094e-01, 1.5071e-01, 2.3886e-02, 3.7857e-03],
[5.0000e+00, 7.9245e-01, 1.2559e-01, 1.9905e-02, 3.1548e-03],
[4.0000e+00, 6.3396e-01, 1.0048e-01, 1.5924e-02, 2.5238e-03],
[3.0000e+00, 4.7547e-01, 7.5357e-02, 1.1943e-02, 1.8929e-03],
[2.0000e+00, 3.1698e-01, 5.0238e-02, 7.9621e-03, 1.2619e-03],
[1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]])
>>> sinusoid_inp.sin()
tensor([[ 1.4988e-01, 1.2993e-01, 4.5935e-01, 7.5568e-02, 1.1988e-02],
[-7.5099e-01, 2.8479e-01, 4.3689e-01, 7.1598e-02, 1.1357e-02],
[-9.6140e-01, 4.3251e-01, 4.1416e-01, 6.7627e-02, 1.0726e-02],
[-2.8790e-01, 5.6939e-01, 3.9117e-01, 6.3654e-02, 1.0095e-02],
[ 6.5029e-01, 6.9200e-01, 3.6793e-01, 5.9681e-02, 9.4642e-03],
[ 9.9061e-01, 7.9726e-01, 3.4446e-01, 5.5706e-02, 8.8333e-03],
[ 4.2017e-01, 8.8254e-01, 3.2077e-01, 5.1731e-02, 8.2024e-03],
[-5.3657e-01, 9.4569e-01, 2.9688e-01, 4.7755e-02, 7.5714e-03],
[-9.9999e-01, 9.8514e-01, 2.7281e-01, 4.3778e-02, 6.9405e-03],
[-5.4402e-01, 9.9990e-01, 2.4856e-01, 3.9800e-02, 6.3095e-03],
[ 4.1212e-01, 9.8959e-01, 2.2415e-01, 3.5822e-02, 5.6786e-03],
[ 9.8936e-01, 9.5448e-01, 1.9960e-01, 3.1843e-02, 5.0476e-03],
[ 6.5699e-01, 8.9544e-01, 1.7493e-01, 2.7864e-02, 4.4167e-03],
[-2.7942e-01, 8.1396e-01, 1.5014e-01, 2.3884e-02, 3.7857e-03],
[-9.5892e-01, 7.1207e-01, 1.2526e-01, 1.9904e-02, 3.1548e-03],
[-7.5680e-01, 5.9234e-01, 1.0031e-01, 1.5924e-02, 2.5238e-03],
[ 1.4112e-01, 4.5775e-01, 7.5285e-02, 1.1943e-02, 1.8929e-03],
[ 9.0930e-01, 3.1170e-01, 5.0217e-02, 7.9621e-03, 1.2619e-03],
[ 8.4147e-01, 1.5783e-01, 2.5116e-02, 3.9811e-03, 6.3096e-04],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]])
使用Transformer-XL的预训练模型经典的就是XLNet啦,可以更好的处理较长的文本。
作者的论文和代码开源:
paper:https://arxiv.org/abs/1901.02860
code:https://github.com/kimiyoung/transformer-xl