The paper's description of this part is not very clear, so I'm recording my learning process here.
This blog post explains it very well; see https://blog.csdn.net/qq_37541097/article/details/121119988 for reference.
import torch

x = torch.arange(2)
y = torch.arange(2)
# takes 1-D tensors and returns two 2-D grids; commonly used to generate coordinates
ox, oy = torch.meshgrid([x, y])
# stacks tensors along a new dimension (dim 0 by default); all inputs must have the same shape
o2 = torch.stack((ox, oy))
print(ox, oy)
print(o2, o2.shape)
# flatten everything from dim 1 onward: 2 x 2 x 2 -> 2 x 4
coords = torch.flatten(o2, 1)
print(coords, coords.shape)
Output:
tensor([[0, 0],
[1, 1]]) tensor([[0, 1],
[0, 1]])
tensor([[[0, 0],
[1, 1]],
[[0, 1],
[0, 1]]]) torch.Size([2, 2, 2])
# two rows of coordinates: the first holds the row (x) indices, the second the column (y) indices
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
torch.Size([2, 4])
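For comparison, the Swin Transformer source builds this same flattened coordinate grid in general form. Below is a minimal sketch for an arbitrary window_size (variable names follow the Swin code; it simply repeats the three steps above):

import torch

window_size = (2, 2)  # (Wh, Ww); any window size works the same way
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
print(coords_flatten)  # identical to coords above for a 2x2 window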
print(coords[:, :, None].shape)  # indexing with None inserts a new dimension
print(coords[:, None, :], coords[:, None, :].shape)
print(coords[:, None, :, None].shape)
# same effect as unsqueeze()
print(coords.unsqueeze(1) == coords[:, None, :])
Output:
torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
torch.Size([2, 1, 4])
torch.Size([2, 1, 4, 1])
tensor([[[True, True, True, True]],
[[True, True, True, True]]])
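A quick way to confirm that equivalence programmatically, continuing with coords from the session above:

# None-indexing inserts a new axis at that position, exactly like unsqueeze
assert torch.equal(coords.unsqueeze(1), coords[:, None, :])
assert torch.equal(coords.unsqueeze(2), coords[:, :, None])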
print(coords[:, :, None])  # same data viewed with an extra trailing dimension
print(coords[:, None, :])
Output:
tensor([[[0],
[0],
[1],
[1]],
[[0],
[1],
[0],
[1]]])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
# (2,4,1) - (2,1,4): broadcast subtraction
relative_coords = coords[:, :, None] - coords[:, None, :]
print(f"relative_coords:{relative_coords.shape}={coords[:,:,None].shape}-{coords[:,None,:].shape}", "\n", relative_coords)
Output:
# the subtraction uses broadcasting: both operands are expanded to the same shape (2,4,4), then subtracted elementwise
relative_coords:torch.Size([2, 4, 4])=torch.Size([2, 4, 1])-torch.Size([2, 1, 4])
tensor([[[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]],
[[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]]])
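To make the broadcasting explicit, the same result can be reproduced by manually expanding both operands to a common shape first; a small check using coords from above:

a = coords[:, :, None].expand(2, 4, 4)  # (2, 4, 1) -> (2, 4, 4)
b = coords[:, None, :].expand(2, 4, 4)  # (2, 1, 4) -> (2, 4, 4)
assert torch.equal(a - b, coords[:, :, None] - coords[:, None, :])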
Permute to [4, 4, 2]: for each of the 4 positions this gives 4 offset pairs, where the first component of each pair is the row offset and the second is the column offset.
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
print(relative_coords.shape)
print(relative_coords)
Output:
torch.Size([4, 4, 2])
tensor([[[ 0, 0],
[ 0, -1],
[-1, 0],
[-1, -1]],
[[ 0, 1],
[ 0, 0],
[-1, 1],
[-1, 0]],
[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]]])
print(relative_coords[:, :, 0])  # channel 0: the row offsets, i.e. the first element of every pair above
print(relative_coords[:, :, 1])  # channel 1: the column offsets
Output:
tensor([[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]])
tensor([[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]])
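At this point every offset lies in the range [-(M-1), M-1], which is why the shift applied next (adding M-1) makes the values start from 0. A small sanity check on relative_coords before it is modified in place (here M = 2):

M = 2
assert relative_coords[:, :, 0].min().item() == -(M - 1)
assert relative_coords[:, :, 1].max().item() == M - 1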
window_size = (2, 2)
# shift both row and column offsets by M-1 so they start from 0; here M = 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords)
# scale the row offsets so that every (row, col) pair maps to a unique scalar
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
print(relative_coords)
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)
Output:
# after adding M-1 to the row offsets (channel 0)
tensor([[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]],
[[ 2, 0],
[ 2, -1],
[ 1, 0],
[ 1, -1]],
[[ 2, 1],
[ 2, 0],
[ 1, 1],
[ 1, 0]]])
# after adding M-1 to the column offsets (channel 1)
tensor([[[1, 1],
[1, 0],
[0, 1],
[0, 0]],
[[1, 2],
[1, 1],
[0, 2],
[0, 1]],
[[2, 1],
[2, 0],
[1, 1],
[1, 0]],
[[2, 2],
[2, 1],
[1, 2],
[1, 1]]])
# after multiplying the row offsets by 2M-1 = 3
tensor([[[3, 1],
[3, 0],
[0, 1],
[0, 0]],
[[3, 2],
[3, 1],
[0, 2],
[0, 1]],
[[6, 1],
[6, 0],
[3, 1],
[3, 0]],
[[6, 2],
[6, 1],
[3, 2],
[3, 1]]])
# summing the two channels yields the final relative position index
tensor([[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]])
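This shift-then-scale-then-sum pipeline is equivalent to the closed-form mapping index = (Δrow + M-1) * (2M-1) + (Δcol + M-1), which assigns a unique value in [0, (2M-1)^2 - 1] to each distinct offset pair. A sketch that recomputes the table above directly from that formula (position k in the flattened window has row k // M and column k % M):

M = 2
idx = torch.empty(M * M, M * M, dtype=torch.long)
for i in range(M * M):
    for j in range(M * M):
        drow = i // M - j // M  # row offset between positions i and j
        dcol = i % M - j % M    # column offset
        idx[i, j] = (drow + M - 1) * (2 * M - 1) + (dcol + M - 1)
print(idx)  # identical to relative_position_index above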
This gives us the relative position index. The corresponding bias values are looked up in the relative position bias table, a learnable parameter defined at the start of the program with (2M-1) * (2M-1) entries. Here M = 2, so the table has 9 entries, which exactly matches the indices 0-8 above.
# define a parameter table of relative position bias
# learnable table with (2*Wh-1) * (2*Ww-1) rows, one column per attention head
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
Assume two attention heads here:
from torch import nn
from timm.models.layers import trunc_normal_

relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # 2*Wh-1 * 2*Ww-1, nH; assume 2 heads
print(relative_position_bias_table.shape, "\n", relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02)  # initialize the bias table
print(relative_position_bias_table)
Output:
torch.Size([9, 2])  # (2M-1)*(2M-1) = 9 entries per head, 2 heads
Parameter containing:
tensor([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], requires_grad=True)
Parameter containing:  # values after trunc_normal_ initialization
tensor([[-0.0340, 0.0181],
[-0.0033, -0.0055],
[ 0.0045, 0.0193],
[ 0.0412, -0.0031],
[ 0.0004, -0.0032],
[ 0.0201, -0.0161],
[ 0.0067, 0.0079],
[ 0.0241, -0.0279],
[-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]
print("index:\n", relative_position_index.view(-1).shape, "\n", relative_position_index.view(-1))
print("values gathered from the bias table:\n", relative_position_bias.shape, "\n", relative_position_bias)
relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
print("after reshape:\n", relative_position_bias.shape, "\n", relative_position_bias)
# permute so the shape matches the attention map: nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
Output:
index:
torch.Size([16])
tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])  # the index flattened to 1-D
values gathered from the bias table:
torch.Size([16, 2])
tensor([[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055],
[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
after reshape:
torch.Size([4, 4, 2])
tensor([[[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181]],
[[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055]],
[[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031]],
[[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)
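Finally, inside Swin's window attention the permuted bias is simply added to the raw attention logits before the softmax. A minimal sketch of that last step, with dummy logits of the assumed shape (num_windows*B, nH, Wh*Ww, Wh*Ww):

attn = torch.randn(1, 2, 4, 4)  # dummy logits: 1 window, 2 heads, 4 tokens
attn = attn + relative_position_bias.unsqueeze(0)  # broadcast over the window/batch dim, as in the Swin source
attn = attn.softmax(dim=-1)
print(attn.shape)  # torch.Size([1, 2, 4, 4])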
That covers all the code involved in the relative position bias.