有关swin transformer相对位置编码的理解:

有关swin transformer相对位置编码的理解:

假设window_size是7*7

那么窗口中共有49个patch,共有49*49个相对位置,每个相对位置有两个索引对应x和y两个方向,每个索引值的取值范围是[-6,6]。(第0行相对第6行,x索引相对值为-6;第6行相对第0行,x索引相对值为6;所以索引取值范围是[-6,6])

    # get pair-wise relative position index for each token inside the window
    coords_h = torch.arange(self.window_size[0])
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # 2, Wh*Ww, Wh*Ww, https://www.cnblogs.com/sgdd123/p/7603004.html
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  
    # Wh*Ww, Wh*Ww, 2, [i,j,:]表示窗口内第i个patch相对于第j个patch的坐标
    relative_coords = relative_coords.permute(1, 2, 0).contiguous() 

此时,构建出来的relative_coords的shape是[49, 49, 2],[i, j, :]表示窗口内第i个patch相对于第j个patch的坐标。

由于此时索引取值范围中包含负值,可分别在每个方向上加上6,使得索引取值从0开始。此时,索引取值范围为[0,12]

    relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += self.window_size[1] - 1

有了这些相对位置坐标之后,就可以根据这些坐标获取对应的position bias,即论文中公式(4)中的B:
Attention ⁡ ( Q , K , V ) = SoftMax ⁡ ( Q K T / d + B ) V \operatorname{Attention}(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V Attention(Q,K,V)=SoftMax(QKT/d +B)V
这个时候可以构建一个shape为[13,13]的table,则当相对位置为(i,j)时,B=table[i, j]。(i,j的取值范围都是[0, 12])

由于论文中使用的时multi-head-self-attention,所以table[i, j]的值应该是一个维度为num_heads的一维向量。

在代码中,实现如下:(注意,此时的table将二维的位置关系,合并为了一维的位置关系)

    # define a parameter table of relative position bias  # shape : 2*Wh-1 * 2*Ww-1, nH
    self.relative_position_bias_table = nn.Parameter(
        torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 

为了与table对应,根据相对位置坐标取值时,也需要将二维相对坐标(i, j)映射为一维相对坐标(i*13+j), 在代码中体现为:

 	relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

最后,就可以根据映射后的坐标来对B进行取值了:

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

附注:

将二维相对坐标(i, j)映射为一维相对坐标时,最简单的映射方式是将i和j相加,但这样无法区分(0, 2)和(2, 0),因为相加的结果都是2;所以作者采用了i*13+j这种方式,其中13 = 2*window_size - 1, 即j取值的最大值。类似于将一个二维数组打平后,每个元素的位置。

你可能感兴趣的:(深度学习,transformer,pytorch,深度学习)