torch.roll() 详解

torch.roll(input, shifts, dims=None) → Tensor

  • input (Tensor) – the input tensor.

  • shifts (int or tuple of python:ints) – The number of places by which the elements of the tensor are shifted. Elements shifted beyond the last position are re-introduced at the first position. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value.

  • dims (int or tuple of python:ints) – Axis along which to roll. If dims is None, the tensor is flattened before rolling and then restored to its original shape.

 

  • input (Tensor) – 输入tensor

  • shifts (int or tuple of python:ints) – 循环移位的步数，为整数或元组。正数表示元素向索引增大的方向移动，负数表示向相反方向移动；移出末尾的元素会从开头重新出现。若为元组，其长度必须与dims相同，每个维度按对应的值移位。

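  • dims (int or tuple of python:ints) – 要移位的维度，即在dims指定的维度上移动shifts对应的步数。对于二维tensor，0为纵向（整行上下移动），1为横向（整列左右移动）。若dims为None，则先将tensor展平，移位后再恢复原shape。若同一维度在dims中出现多次，其移位量会累加（见下面示例1：两次-1相当于沿0维移动-2）。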

 

代码示例: 

 

import torch

# Demo 1: roll a 3x3 tensor with tuple-valued shifts/dims.
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print(x, '\n')
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

# Each entry is (shifts, dims); shifts[k] is applied along dims[k].
# When the same dim is listed twice its shifts add up, so case 1 rolls
# dim 0 by -2 in total and case 9 rolls dim 0 by +2.
# The trailing comment shows the expected rows of each printed result.
cases = [
    ((-1, -1), (0, 0)),  # 1: [7 8 9] [1 2 3] [4 5 6]
    ((-1, 0), (0, 1)),   # 2: [4 5 6] [7 8 9] [1 2 3]
    ((-1, 1), (1, 0)),   # 3: [8 9 7] [2 3 1] [5 6 4]
    ((0, -1), (1, 1)),   # 4: [2 3 1] [5 6 4] [8 9 7]
    ((0, 0), (1, 1)),    # 5: [1 2 3] [4 5 6] [7 8 9]
    ((0, 1), (1, 1)),    # 6: [3 1 2] [6 4 5] [9 7 8]
    ((1, -1), (1, 0)),   # 7: [6 4 5] [9 7 8] [3 1 2]
    ((1, 0), (0, 1)),    # 8: [7 8 9] [1 2 3] [4 5 6]
    ((1, 1), (0, 0)),    # 9: [4 5 6] [7 8 9] [1 2 3]
]

# Print the case number, then the rolled tensor, for every case in order.
for case_no, (shifts, dims) in enumerate(cases, start=1):
    print(case_no)
    print(torch.roll(x, shifts, dims), '\n')

# Drop the loop helpers so only `x` remains defined by this demo.
del cases, case_no, shifts, dims






# Demo 2: roll a 4x2 tensor by every shift in -3..3, first along dim 1,
# then along dim 0, then with no dims at all.
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
print(x, '\n')
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

# dims=1 (2 columns): rolling is periodic with period 2, so odd shifts
# (-3, -1, 1, 3) swap the columns -> [2 1] [4 3] [6 5] [8 7],
# and even shifts (-2, 0, 2) leave x unchanged.
for shift in range(-3, 4):
    print(torch.roll(x, shift, 1), '\n')

# dims=0 (4 rows): whole rows move; -3 is equivalent to +1, +3 to -1.
#   -3: [7 8] [1 2] [3 4] [5 6]
#   -2: [5 6] [7 8] [1 2] [3 4]
#   -1: [3 4] [5 6] [7 8] [1 2]
#    0: [1 2] [3 4] [5 6] [7 8]
#    1: [7 8] [1 2] [3 4] [5 6]
#    2: [5 6] [7 8] [1 2] [3 4]
#    3: [3 4] [5 6] [7 8] [1 2]
for shift in range(-3, 4):
    print(torch.roll(x, shift, 0), '\n')

# dims omitted: x is flattened to [1..8], rolled, then viewed as 4x2 again.
#   -3: [4 5] [6 7] [8 1] [2 3]
#   -2: [3 4] [5 6] [7 8] [1 2]
#   -1: [2 3] [4 5] [6 7] [8 1]
#    0: [1 2] [3 4] [5 6] [7 8]
#    1: [8 1] [2 3] [4 5] [6 7]
#    2: [7 8] [1 2] [3 4] [5 6]
#    3: [6 7] [8 1] [2 3] [4 5]
for shift in range(-3, 4):
    print(torch.roll(x, shift), '\n')

# Drop the loop variable so only `x` remains defined by this demo.
del shift

 

 

参考:https://pytorch.org/docs/master/generated/torch.roll.html 

 

你可能感兴趣的:(深度学习,python,torch)