GPSA Positional Attention in ConViT

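import torch

def get_rel_indices(self, num_patches: int) -> torch.Tensor:
    # Method of the GPSA module; assumes the patches form a square grid.
    img_size = int(num_patches ** .5)
    rel_indices = torch.zeros(1, num_patches, num_patches, 3)
    # ind[i, j] = j - i: signed offset between two coordinates on one axis
    ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
    # Horizontal (column) offset between every pair of patches
    indx = ind.repeat(img_size, img_size)
    # Vertical (row) offset between every pair of patches
    indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
    # Squared Euclidean distance between every pair of patches
    indd = indx ** 2 + indy ** 2
    rel_indices[:, :, :, 2] = indd.unsqueeze(0)
    rel_indices[:, :, :, 1] = indy.unsqueeze(0)
    rel_indices[:, :, :, 0] = indx.unsqueeze(0)
    # Move the result to the same device as the module's weights
    device = self.qk.weight.device
    return rel_indices.to(device)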

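First, torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)

produces a matrix of offsets along one axis. Entry [i, j] equals j - i, so this is a relative position encoding rather than an absolute one. For img_size = 14 it looks like this:

[[0,1,2,3,4,5,6,7,8,9,10,11,12,13],
 [-1,0,1,2,3,4,5,6,7,8,9,10,11,12],
 [-2,-1,0,1,2,3,4,5,6,7,8,9,10,11],
 ...
 [-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0]]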
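
To make the pattern easier to see, here is a minimal sketch of the same subtraction on a 3×3 grid. The value img_size = 3 is chosen only for illustration and is not from the original code.

```python
import torch

img_size = 3  # small illustrative grid so the whole matrix fits on screen

# A row vector minus a column vector broadcasts to an (img_size, img_size)
# matrix whose entry [i, j] is j - i.
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
print(ind)
# tensor([[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]])
```

The diagonal is zero (a coordinate compared with itself), and the sign tells on which side the other coordinate lies.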

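Next, repeat tiles this offset matrix img_size times along both dimensions. The result, indx, has shape num_patches × num_patches and holds the horizontal offset between every pair of patches. Its first row is:

[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,0,1,2,3,4,5,6,7,8,9,10,11,12,13,0,1,2,3,4,5,6,7,8,9,10,11,12,13,...],
 ...]

The tiling is applied in the same way along both dimensions.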

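Then repeat_interleave repeats each element of the same offset matrix img_size times along both dimensions. The result, indy, also has shape num_patches × num_patches and holds the vertical offset between every pair of patches. With img_size = 14, the first 14 rows are identical, and the next block of 14 rows starts with -1:

[[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...],
 ...
 [-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...],
 ...]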
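
As a sanity check on the two tiling steps, the sketch below compares indx and indy with offsets computed directly from each patch's (row, col) coordinates. It assumes patches are numbered in row-major order and uses a 3×3 grid for illustration; the names rows, cols, dx and dy are introduced here and do not appear in the original code.

```python
import torch

img_size = 3                      # illustrative grid size
num_patches = img_size ** 2

ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
indx = ind.repeat(img_size, img_size)
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)

# Explicit grid coordinates of each patch, assuming row-major numbering.
rows = torch.arange(num_patches) // img_size
cols = torch.arange(num_patches) % img_size

# Offset from patch p (matrix row) to patch q (matrix column).
dx = cols.view(1, -1) - cols.view(-1, 1)
dy = rows.view(1, -1) - rows.view(-1, 1)

assert torch.equal(indx, dx)      # repeat -> column offsets
assert torch.equal(indy, dy)      # repeat_interleave -> row offsets

print(indx[0].tolist())  # [0, 1, 2, 0, 1, 2, 0, 1, 2]
print(indy[0].tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
```

So repeat cycles through the column offsets while repeat_interleave holds each row offset for a whole grid row, which is what the two example matrices show.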

The third encoding combines the two above: indd = indx ** 2 + indy ** 2, which is the squared Euclidean distance between each pair of patches.
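
A small sketch of what indd contains: on a 3×3 grid (an illustrative size), the row of indd belonging to the center patch can be viewed as a grid of squared distances from that patch. The variable center is introduced here for the example and assumes row-major patch numbering.

```python
import torch

img_size = 3  # illustrative grid size
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
indx = ind.repeat(img_size, img_size)
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
indd = indx ** 2 + indy ** 2

center = 4  # patch at row 1, col 1, assuming row-major numbering
print(indd[center].view(img_size, img_size))
# tensor([[2, 1, 2],
#         [1, 0, 1],
#         [2, 1, 2]])
```

The direct neighbours are at squared distance 1 and the diagonal neighbours at 2, so this channel lets a head attend according to how far away a patch is, regardless of direction.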

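Finally, the three encodings are stacked as three channels in the last dimension, and a mapping projects those 3 channels to num_heads channels. The reason is that multi-head attention computes one attention map per head, so it needs one positional matrix per head. The projected tensor is then rearranged to the same shape as the multi-head attention matrix so that the two can be combined.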
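
The projection step might look roughly like the sketch below. The name pos_proj, the tensor sizes, the random stand-in for the relative indices, and the fixed 0.5 mixing weight are all assumptions made for illustration; the text above only says the two matrices are combined, and the real model may combine them differently.

```python
import torch
import torch.nn as nn

num_heads, num_patches, batch = 4, 9, 2   # illustrative sizes

# Stand-in for the output of get_rel_indices: shape (1, N, N, 3).
rel_indices = torch.randint(-2, 3, (1, num_patches, num_patches, 3)).float()

# Hypothetical projection from the 3 positional channels to one per head.
pos_proj = nn.Linear(3, num_heads)

pos_score = pos_proj(rel_indices)            # (1, N, N, num_heads)
pos_score = pos_score.permute(0, 3, 1, 2)    # (1, num_heads, N, N)
pos_score = pos_score.softmax(dim=-1)        # normalize over key patches

# Stand-in content attention with the usual multi-head shape (B, heads, N, N).
patch_score = torch.randn(batch, num_heads, num_patches, num_patches).softmax(dim=-1)

# Placeholder combination: an equal-weight average (an assumption of this sketch).
attn = 0.5 * patch_score + 0.5 * pos_score.expand(batch, -1, -1, -1)
print(attn.shape)  # torch.Size([2, 4, 9, 9])
```

Because both terms are row-normalized, their average is still a valid attention distribution over the key patches.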
