transformer中的相对位置偏置的介绍(relative position bias)

前言

在很多近期的transformer工作中,经常提到一个词: relative position bias。用在self attention的计算当中。笔者在第一次看到这个概念时,不解其意,本文用来笔者自己关于relative position bias的理解。

笔者第一次看到该词是在swin transformer。后来在focal transformer和LG-transformer中都看到它。

relative position bias(相对位置偏置)

基本形式如下:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T + B ) V Attention(Q, K, V) = Softmax(QK^T + B)V Attention(Q,K,V)=Softmax(QKT+B)V
其中 Q , V ∈ R n × d Q, V \in R^{n\times d} Q,VRn×d, B ∈ R n × n B \in R^{n \times n} BRn×n,n是token vector的数目。可以看出,B的作用是给attention map Q K T QK^T QKT的每个元素加了一个值。其本质就是希望attention map进一步有所偏重。因为attention map中某个值越低,经过softmax之后,该值会更低。对最终特征的贡献就低。

而B并不是一个随便初始化的参数,它有一个完备的使用过程。其基本过程如下:

  • 初始化一个 n 2 n^2 n2的tensor作为表,同时也是个参数。
  • 构建table index,用于根据位置查表。下面再介绍细节
  • 前向传播中使用位置查表。
  • 反向传播更新表。

在swin transformer的源代码中,可以清楚的看到相对位置偏置的使用过程。
在构造函数中,有以下相关内容:

self.relative_position_bias_table = nn.Parameter(
         torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
         
# get pair-wise relative position index for each token inside the window
 coords_h = torch.arange(self.window_size[0])
 coords_w = torch.arange(self.window_size[1])
 coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
 coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
 relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
 relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
 relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
 relative_coords[:, :, 1] += self.window_size[1] - 1
 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
 relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
 self.register_buffer("relative_position_index", relative_position_index)

第一行就是初始化表,
后面的内容就是建立一个可以根据query和key的相对位置查表参数的index。
比如现在有一个 2 × 2 2 \times 2 2×2的特征图。设置windows size 为(2,2),我们可以看看relative_position_index长什么样:
transformer中的相对位置偏置的介绍(relative position bias)_第1张图片

torch.Size([4, 4])
tensor([
[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]
])

注意看,主对角线都是4。上三角都是比4小,最低为0;下三角都比4大,且最大为8;一共9个数字,正好等于relative_position_bias_table的宽和高。
以第一行为例,第一个元素为4,第二个元素为3;对应就是方格图中标号为1和2的位置。
其实就是第一个query和第一个key都在标号1的位置,所以相对位置为0,则都使用参数表的第4个偏置;而第2个key中的元素,位置在标号1的右边一格,用参数表的第3个参数。
重点到了:只要query在key’的左边一格,relative_position_index中对应的位置都是3。比如第三行的最后一个数字。第三行对应标号3,只有标号4在其右边,而relative_position_bias_table[2][3]恰好为3。
进而可以观察其他元素之间的位置,可以发现相同的规律。因此,不难得出结论:B中值和query和key的相对位置有关系。相对位置一致的query-key pair,会采用相同的bias。

q = q * self.scale
attn = (q @ k.transpose(-2, -1))
# 根据index提供的相对位置映射查bias,然后在view成可以和attention map计算的B
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
   			self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)

reference

swin transformer
图解Swin Transformer

你可能感兴趣的:(深度学习,transformer,深度学习,pytorch)