torch.roll(input, shifts, dims=None)

input (Tensor) – the input tensor.
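shifts (int or tuple of ints) – The number of places by which the elements of the tensor are shifted; elements shifted past the last position re-enter at the first position. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value.
dims (int or tuple of ints) – Axis along which to roll. If dims is None, the tensor is flattened before rolling and then restored to its original shape.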
input (Tensor) – 输入的 tensor
shifts (int or tuple of ints) – 移动的位数,为整数或元组。若为元组,则 dims 也必须是长度相同的元组,每个维度按对应的值移动
dims (int or tuple of ints) – 维度。在 dims 指定的维度上移动 shifts 位;对二维 tensor 而言,0 为纵向(按行移动),1 为横向(按列移动)。若 dims 为 None,则先将 tensor 展平再移动,最后恢复原 shape
import torch
# Demo 1: tuple-valued `shifts` / `dims` on a 3x3 tensor.
# Each case rolls x by shifts[i] along dims[i]. When the same dim is listed
# twice, e.g. dims=(0, 0), the shifts add up: case 1 rolls rows by -2, which
# on a 3-row tensor is the same as rolling by +1.
# Every case's expected result is asserted, so the documented outputs below
# cannot silently drift from what torch.roll actually returns.
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print(x, '\n')
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

# (shifts, dims, expected result). The printed label is the 1-based case number.
# "rows" = dim 0, "cols" = dim 1 of the (3, 3) view above.
ROLL_CASES_3X3 = [
    ((-1, -1), (0, 0), [[7, 8, 9], [1, 2, 3], [4, 5, 6]]),  # rows -2 == rows +1
    ((-1, 0), (0, 1), [[4, 5, 6], [7, 8, 9], [1, 2, 3]]),   # rows -1
    ((-1, 1), (1, 0), [[8, 9, 7], [2, 3, 1], [5, 6, 4]]),   # cols -1, rows +1
    ((0, -1), (1, 1), [[2, 3, 1], [5, 6, 4], [8, 9, 7]]),   # cols -1
    ((0, 0), (1, 1), [[1, 2, 3], [4, 5, 6], [7, 8, 9]]),    # no-op
    ((0, 1), (1, 1), [[3, 1, 2], [6, 4, 5], [9, 7, 8]]),    # cols +1
    ((1, -1), (1, 0), [[6, 4, 5], [9, 7, 8], [3, 1, 2]]),   # cols +1, rows -1
    ((1, 0), (0, 1), [[7, 8, 9], [1, 2, 3], [4, 5, 6]]),    # rows +1
    ((1, 1), (0, 0), [[4, 5, 6], [7, 8, 9], [1, 2, 3]]),    # rows +2 == rows -1
]

for case_no, (shifts, dims, expected) in enumerate(ROLL_CASES_3X3, start=1):
    print(case_no)
    rolled = torch.roll(x, shifts, dims)
    print(rolled, '\n')
    # Compare as nested lists so the check is independent of tensor dtype.
    assert rolled.tolist() == expected, (case_no, rolled.tolist(), expected)
# Demo 2: integer `shifts` from -3 to 3 on a 4x2 tensor, rolled along dim 1,
# along dim 0, and with `dims` omitted. Each result is printed and then
# checked against the documented expectation.
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
print(x, '\n')
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

_IDENTITY_4X2 = [[1, 2], [3, 4], [5, 6], [7, 8]]
_COLS_SWAPPED_4X2 = [[2, 1], [4, 3], [6, 5], [8, 7]]

# (shift, dim, expected result). dim=None means torch.roll is called without
# `dims`, so x is flattened, rolled, and restored to shape (4, 2).
ROLL_CASES_4X2 = [
    # dim=1 has size 2: odd shifts swap the two columns, even shifts are no-ops.
    (-3, 1, _COLS_SWAPPED_4X2),
    (-2, 1, _IDENTITY_4X2),
    (-1, 1, _COLS_SWAPPED_4X2),
    (0, 1, _IDENTITY_4X2),
    (1, 1, _COLS_SWAPPED_4X2),
    (2, 1, _IDENTITY_4X2),
    (3, 1, _COLS_SWAPPED_4X2),
    # dim=0 has size 4: shifts wrap modulo 4, so -3 == 1 and 3 == -1.
    (-3, 0, [[7, 8], [1, 2], [3, 4], [5, 6]]),
    (-2, 0, [[5, 6], [7, 8], [1, 2], [3, 4]]),
    (-1, 0, [[3, 4], [5, 6], [7, 8], [1, 2]]),
    (0, 0, _IDENTITY_4X2),
    (1, 0, [[7, 8], [1, 2], [3, 4], [5, 6]]),
    (2, 0, [[5, 6], [7, 8], [1, 2], [3, 4]]),
    (3, 0, [[3, 4], [5, 6], [7, 8], [1, 2]]),
    # dims omitted: roll the flattened sequence 1..8, then reshape to (4, 2).
    (-3, None, [[4, 5], [6, 7], [8, 1], [2, 3]]),
    (-2, None, [[3, 4], [5, 6], [7, 8], [1, 2]]),
    (-1, None, [[2, 3], [4, 5], [6, 7], [8, 1]]),
    (0, None, _IDENTITY_4X2),
    (1, None, [[8, 1], [2, 3], [4, 5], [6, 7]]),
    (2, None, [[7, 8], [1, 2], [3, 4], [5, 6]]),
    (3, None, [[6, 7], [8, 1], [2, 3], [4, 5]]),
]

for shift, dim, expected in ROLL_CASES_4X2:
    # Omit the dims argument entirely for the flatten case, as the original
    # torch.roll(x, shift) calls did.
    rolled = torch.roll(x, shift) if dim is None else torch.roll(x, shift, dim)
    print(rolled, '\n')
    assert rolled.tolist() == expected, (shift, dim, rolled.tolist())
参考:https://pytorch.org/docs/master/generated/torch.roll.html