Swin Transformer代码中对relative_bias-index的理解(pytorch)

在查看pytorch实现的Swin Transformer源码中,对relative_bias-index编程部分有些不太理解,可能是因为自己python基础差的原因,导致有些代码看不懂。后面经过查看资料才弄懂了。在此记录下自己学习的过程。

我查看的是B站up主霹雳吧啦Wz视频中提供的代码,代码网址如下:pytorch_classification/swin_transformer

在源码的第218-231行实现了对relative_bias-index的创建。以下为代码的截图:

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第1张图片

比如window_size是2×2(原论文中为7×7还有其他的可供选择,我这里为简化,设置为(2×2),那么218和219行分别使用arange函数得到两个列表,从0开始到window_size[0]-1的列表和从0开始到window_size[1]-1的列表:

coords_h=[0,1],coords_w=[0,1]

代码221行首先对coords_h和coords_w使用meshgrid转化为网格,然后进行stack拼接。关于meshgrid方法,这个博客把原理讲解的很清楚(https://blog.csdn.net/qq_41375609/article/details/102828154)。

对coords_h和coords_w使用meshgrid方法后,会得到两个二维矩阵,分别记为:

然后对x和y进行stack操作(默认是第0维度),会把x和y在channel方向进行拼接(红色的矩阵在黑色矩阵的后面),得到一个[2,2,2]的矩阵。

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第2张图片

第 222行对上面的两个矩阵进行展平操作(把每个矩阵展平为一个行向量),得到一个[2,4]的矩阵:

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第3张图片
Swin Transformer代码中对relative_bias-index的理解(pytorch)_第4张图片
Swin Transformer代码中对relative_bias-index的理解(pytorch)_第5张图片

在深度方向复制4倍,变为[2,4,4]的张量,如下图所示:

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第6张图片
Swin Transformer代码中对relative_bias-index的理解(pytorch)_第7张图片

将两个张量转化为相同维度后,就能进行相减了,得到的结果也为[2,4,4]:

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第8张图片

然后进行226行操作,对上面的张量relative_coords进行permute(1, 2, 0),就是将张量的形状变为[4,4,2],如下图所示:

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第9张图片

relative_coords[:,:,0]代表所有的黑色元素,也就是所有relative_bias-index的行标; relative_coords[:,:,1]代表所有的红色元素,也就是所有relative_bias-index的列标。

第227行是对所有relative_bias-index的行标+window_size[0] - 1 ,第228行是对所有relative_bias-index的列标+window_size[1] - 1 。第229行是对所有relative_bias-index的行标都+2×window_size[0] - 1。第230行是对经过一系列行标列标计算的relative_bias-index的行标与列标进行相加,得到一个[4,4]的矩阵,也就是最终的relative_bias-index。

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第10张图片

根据relative_bias-index的索引内容,从relative_position_bias_table(表中的bias是可训练的参数)中取相应位置的bias,即为对应位置的偏置。(下图来自B站up主霹雳吧啦Wz的博客,地址为(https://blog.csdn.net/qq_37541097/article/details/121119988?spm=1001.2014.3001.5502

Swin Transformer代码中对relative_bias-index的理解(pytorch)_第11张图片

你可能感兴趣的:(计算机视觉,深度学习,python)