torch.roll

x = torch.arange(1,10).view(3,3)
print(x,"\n")
print(torch.roll(x,shifts=1,dims=0)) #shift 表示偏移几次 ,dim表示沿着哪个维度 这里dim=0表示行
print(torch.roll(x,shifts=1,dims=0)==torch.roll(x,shifts=4,dims=0)) #可见shift是可以循环计数,以3为周期

#参数可以是组合形式,表示先后进行沿着地dim 0 ,1先后偏移1行,1列 ,注意:shift是整体循环偏移,不是取代指定位置的值
torch.roll(x,shifts=(1,1),dims=(0,1))

输出

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]) 

tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6]])
        
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
tensor([[9, 7, 8],
        [3, 1, 2],
        [6, 4, 5]])
x = torch.arange(1,10).view(3,3)
print(x,"\n")
print(torch.roll(x,shifts=1,dims=0)) #shift 正负表示正向偏移 和反向偏移 
print(torch.roll(x,shifts=-1,dims=0))

输出

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]) 

tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6]])
tensor([[4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]])

你可能感兴趣的:(pytorch,python,pytorch,python)