相对位置编码 relative position encoding

文章是对视频的总结原作者

Transformer中有两种常用的位置编码,分别为绝对位置编码和相对位置编码,其中绝对位置编码使用的比较多,也是比较简单的。在代码中直接初始化0矩阵,在forward函数中直接与x相加:

相对位置编码 relative position encoding_第1张图片

self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels))

def forward(self,x)

embeddings = x + self.position_embeddings

而相对位置编码,是根据token与token之间的位置关系来生成权重:

假如有5个token,其中一个token与其他所有位置包括自己在内的token之间存在一个权重。 

相对位置编码 relative position encoding_第2张图片

 w0表示x4与自己的位置关系,0表示与自己的距离,w1表示向右移动一个位置,w-1表示向左移动一个位置。

相对位置编码 relative position encoding_第3张图片

x3也是同理。

相对位置编码 relative position encoding_第4张图片

 这样第一个到最后一个就可以表示为:

相对位置编码 relative position encoding_第5张图片

 相对位置编码 relative position encoding_第6张图片

 一共有9个不同的位置编码,分别为w-4, w-3, w-2, w-1, w0, w1, w2, w3, w4。用图片表示为:

相对位置编码 relative position encoding_第7张图片

 我们用标识对表示为:

相对位置编码 relative position encoding_第8张图片

 作者使用了一个k阈值,当超过这个特定的阈值,即其他的position_embedding距离自身超过两个位置,其他的position_embedding就和距离最近的position_embedding值一样。

相对位置编码 relative position encoding_第9张图片

假设k=2,w3和w4就会变成w2,后面同理。 

相对位置编码 relative position encoding_第10张图片

那么如何用公式表示呢:

相对位置编码 relative position encoding_第11张图片

作者在text上进行了实验,但是词链就是一个展平了的graph,所以这种方法可以运用到图中,只要你元素中有成对的关系就可以使用。

在小数据集上,相比卷积transformer缺少内在偏置,因此需要加入相对位置编码来解决这个问题,而对于特大数据集,transformer就可以学习到卷积的内在偏置,即这也是为什么在小数据集上transformer的效果不如卷积,而在大数据集上,transformer效果优于卷积。

--------------------------------------------------分割线----------------------------------------------------------------------

继续补充对相对位置编码的理解,图片和思路来自B站霹雳吧啦Wz:

结合上面的知识,和上面的text一维不同,图片是二维的,同时拥有长和宽,对图片的每一个位置进行编码,第一个图片四个位置的索引,等于(0,0)-(0,0),(0,0)-(0,1),(0,0)-(1,0),(0,0)-(1,1),其他的三个图片同理。

相对位置编码 relative position encoding_第12张图片

将上面的相对位置索引进行展平,变成了如图4x4的矩阵。 

在相对位置偏移公式中,B是相对位置偏移,不是相对位置索引,我们根据每一个相对位置索引(4x4矩阵)去relative  position bias table去取相应的参数。

但是我们看到在蓝色方块的右边和红色方块的右边两个索引一样,所以不能简单的行列相加,并且在原文章中,我们使用的一元位置坐标,接下来我们进行转换:

首先:将数值限制到大于0。

相对位置编码 relative position encoding_第13张图片

接着行标,第一个数乘以2M-1。

相对位置编码 relative position encoding_第14张图片

 最后得到一个新的索引:相对位置编码 relative position encoding_第15张图片

 根据索引我们去relative  position bias table中查找对应的值。

相对位置编码 relative position encoding_第16张图片

 其中训练网络中针对训练的是relative  position bias table,即relative  position bias table是训练出来的,而窗口m固定,index也是固定的。索引的范围行(-1,0,1),列(-1,0,1)共九个。

得到的bias即公式中的B。

参考:霹雳吧啦Wz

你可能感兴趣的:(paper总结,深度学习,人工智能,python)