def get_rel_indices(self, num_patches: int) -> torch.Tensor:
    """Build a (1, num_patches, num_patches, 3) tensor of relative patch offsets.

    For each pair of patches (i, j) on the sqrt(N) x sqrt(N) grid, the last
    dimension holds (dx, dy, dx**2 + dy**2): the horizontal offset, the
    vertical offset, and the squared Euclidean distance between them.
    Assumes num_patches is a perfect square. The result is placed on the
    device of the qk projection weights.
    """
    side = int(num_patches ** .5)
    # (side, side) matrix of 1-D offsets: delta[a, b] = b - a
    grid = torch.arange(side)
    delta = grid.view(1, -1) - grid.view(-1, 1)
    # Tile / interleave the 1-D offsets into full (N, N) offset maps:
    # dx varies along the fast (column-within-row) axis, dy along the slow one.
    dx = delta.repeat(side, side)
    dy = delta.repeat_interleave(side, dim=0).repeat_interleave(side, dim=1)
    dist_sq = dx ** 2 + dy ** 2
    # Stack channels as (dx, dy, dist_sq), add a batch dim, match float32
    # (the original buffer was created with torch.zeros).
    rel = torch.stack((dx, dy, dist_sq), dim=-1).float().unsqueeze(0)
    return rel.to(self.qk.weight.device)
首先由 torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1,1)
产生一维相对位置偏移矩阵(每个元素是列下标与行下标之差),例如 img_size=14 时:
[[0,1,2,3,4,5,6,7,8,9,10,11,12,13],
 [-1,0,1,2,3,4,5,6,7,8,9,10,11,12],
 [-2,-1,0,1,2,3,4,5,6,7,8,9,10,11],
 ...
 [-13,-12,-11,-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0]]
然后用 repeat 函数把该偏移矩阵平铺,产生 N×N(N=num_patches)的水平方向相对位置编码 indx:
[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,0,1,2,3,4,5,6,7,8,9,10,11,12,13,0,1,...],
 ...]
对另一个维度进行类似的操作:
用 repeat_interleave 函数对偏移矩阵逐元素重复,产生 N×N 的垂直方向相对位置编码 indy:
[[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,...],
 [-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...],
 ...]
第三种编码由前两种组合而成:indd = indx**2 + indy**2,即两个 patch 之间的平方欧氏距离。
然后将三种编码堆叠成 3 个通道,再通过一个线性映射把 3 通道映射成 num_heads 个通道。这么做的原因是:多头注意力有 num_heads 个头,需要 num_heads 个注意力矩阵;从而将 3 通道的位置矩阵映射成 num_heads 个,并 reshape 成与多头注意力矩阵形状相同的矩阵,以便与多头注意力矩阵进行结合。