x = torch.arange(1,10).view(3,3)
print(x,"\n")
print(torch.roll(x,shifts=1,dims=0)) #shift 表示偏移几次 ,dim表示沿着哪个维度 这里dim=0表示行
print(torch.roll(x,shifts=1,dims=0)==torch.roll(x,shifts=4,dims=0)) #可见shift是可以循环计数,以3为周期
#参数可以是组合形式,表示先后进行沿着地dim 0 ,1先后偏移1行,1列 ,注意:shift是整体循环偏移,不是取代指定位置的值
torch.roll(x,shifts=(1,1),dims=(0,1))
输出
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[7, 8, 9],
[1, 2, 3],
[4, 5, 6]])
tensor([[True, True, True],
[True, True, True],
[True, True, True]])
tensor([[9, 7, 8],
[3, 1, 2],
[6, 4, 5]])
x = torch.arange(1,10).view(3,3)
print(x,"\n")
print(torch.roll(x,shifts=1,dims=0)) #shift 正负表示正向偏移 和反向偏移
print(torch.roll(x,shifts=-1,dims=0))
输出
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[7, 8, 9],
[1, 2, 3],
[4, 5, 6]])
tensor([[4, 5, 6],
[7, 8, 9],
[1, 2, 3]])