torch.roll: cyclically shift (roll) a tensor's elements along given dimensions
torch.roll(input, shifts, dims=None) → Tensor
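| Parameter | Description |
| --- | --- |
| input (Tensor) | the input tensor |
| shifts (int or tuple of int) | number of places by which the elements are shifted (positive values move elements toward higher indices); elements pushed past the last position re-enter at the first. If a tuple, `dims` must be a tuple of the same length, and each dimension is rolled by the corresponding value |
| dims (int or tuple of int, optional) | dimension(s) along which to roll; if `None`, the tensor is flattened before rolling and then restored to its original shape |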
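A minimal sketch of two cases from the table that the examples below do not cover: a single integer shift with an explicit `dims` (a negative shift rolls the other way), and `dims=None` (flatten, roll, reshape back). The tensor `x` is just an illustrative input.

```python
import torch

x = torch.arange(6).reshape(2, 3)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

# Positive shift along dim 1: each row moves right by 1, the last column wraps to the front
right = torch.roll(x, shifts=1, dims=1)
# tensor([[2, 0, 1],
#         [5, 3, 4]])

# Negative shift along dim 1: each row moves left by 1, the first column wraps to the end
left = torch.roll(x, shifts=-1, dims=1)
# tensor([[1, 2, 0],
#         [4, 5, 3]])

# dims=None: x is flattened to [0, 1, 2, 3, 4, 5], rolled by 1, then reshaped to (2, 3)
flat = torch.roll(x, shifts=1)
# tensor([[5, 0, 1],
#         [2, 3, 4]])

print(right, left, flat, sep="\n")
```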
import torch
a = torch.arange(16).reshape(4, 4)
print(a)
print(torch.roll(a, shifts=(1, 1), dims=(0, 1)))  # rows move down 1, columns move right 1
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[15, 12, 13, 14],
        [ 3,  0,  1,  2],
        [ 7,  4,  5,  6],
        [11,  8,  9, 10]])
'''
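To read the output above: with `shifts=(s0, s1)` each element moves `s0` rows down and `s1` columns right, wrapping around, so `out[i, j] == a[(i - s0) % 4, (j - s1) % 4]`; for example `out[0, 0]` is `a[3, 3]`, i.e. 15. The sketch below checks this formula with `roll2d_reference`, a hypothetical helper built from plain indexing (not a PyTorch API), and also checks that rolling two dimensions at once matches rolling them one at a time.

```python
import torch

def roll2d_reference(a: torch.Tensor, s0: int, s1: int) -> torch.Tensor:
    # Hypothetical helper for illustration only: out[i, j] = a[(i - s0) % n0, (j - s1) % n1]
    n0, n1 = a.shape
    rows = (torch.arange(n0) - s0) % n0  # tensor % follows Python's sign rule, so -1 % 4 == 3
    cols = (torch.arange(n1) - s1) % n1
    return a[rows][:, cols]  # pick rows first, then columns

a = torch.arange(16).reshape(4, 4)
rolled = torch.roll(a, shifts=(1, 1), dims=(0, 1))

# Same result as the explicit index formula
print(torch.equal(rolled, roll2d_reference(a, 1, 1)))  # True

# Same result as rolling dim 0 first, then dim 1
step_by_step = torch.roll(torch.roll(a, shifts=1, dims=0), shifts=1, dims=1)
print(torch.equal(rolled, step_by_step))  # True
```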
a = torch.arange(16).reshape(4, 4)
print(a)
print(torch.roll(a, shifts=(2, 2), dims=(0, 1)))  # rows move down 2, columns move right 2
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[10, 11,  8,  9],
        [14, 15, 12, 13],
        [ 2,  3,  0,  1],
        [ 6,  7,  4,  5]])
'''
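Because the shift wraps around, only its value modulo the dimension size matters. On this 4×4 tensor, `shifts=(6, 6)` should give the same result as `(2, 2)`, a shift of 4 should return the original, and rolling by the negated shifts should undo the roll. A short sketch checking these properties (the variable names are illustrative):

```python
import torch

a = torch.arange(16).reshape(4, 4)
r2 = torch.roll(a, shifts=(2, 2), dims=(0, 1))

# A shift larger than the dimension size wraps: 6 % 4 == 2
r6 = torch.roll(a, shifts=(6, 6), dims=(0, 1))
# A shift equal to the dimension size is a no-op
r4 = torch.roll(a, shifts=(4, 4), dims=(0, 1))
# Rolling by the negated shifts restores the original tensor
undone = torch.roll(r2, shifts=(-2, -2), dims=(0, 1))

print(torch.equal(r6, r2), torch.equal(r4, a), torch.equal(undone, a))  # True True True
```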