Swin Transformer中torch.roll()详解

torch.roll()这个函数看官方解释很懵,直接对照可视化来理解
参考:torch.roll 函数的理解

torch.roll(x, shifts=(40, 40), dims=(1, 2))
Swin Transformer中torch.roll()详解_第1张图片
这里img的shape是[1,56,56,96],即[B,H,W,C]格式。
dim=1,shift=40指的就是数据沿着H维度,将数据朝正反向滚动40,超出部分循环回到图像中
dim=2,shift=40指的就是数据沿着W维度,将数据朝正反向滚动40,超出部分循环回到图像中
这里的原点是左上角,H的正方向向下,W正方向向右
可视化代码:

import torch    
import numpy as np   
import matplotlib.pyplot as plt

shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)

if shift_size > 0:
    shifted_x = torch.roll(x, shifts=(40, 40), dims=(1, 2))
    #shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    print("---------经过循环位移了---------")
else:
    shifted_x = x
   
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
    plt.title("non_shifted")
else:
    plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()

你可能感兴趣的:(transformer,深度学习,python)