本文转载于https://heroinlin.github.io/2018/03/12/Pytorch/Pytorch_tensor_flip/
在使用numpy时我们可以对数组进行镜像翻转操作,如以下例子
import numpy as np
array = np.array(range(10))
print(array)
print(array[::-1])
[0 1 2 3 4 5 6 7 8 9]
[9 8 7 6 5 4 3 2 1 0]
但是在pytorch中并不能通过tensor[::-1]进行镜像的翻转,此处给出了tensor的镜像翻转方法
# https://github.com/pytorch/pytorch/issues/229
import torch
from torch.autograd import Variable
def flip(x, dim):
xsize = x.size()
dim = x.dim() + dim if dim < 0 else dim
x = x.view(-1, *xsize[dim:])
x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1,
-1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
return x.view(xsize)
# Code to test it with cpu Variable
a = Variable(torch.Tensor([range(1, 25)]).view(1, 2, 3, 4))
print(a)
print(flip(a, 0)) # Or -4
print(flip(a, 1)) # Or -3
print(flip(a, 2)) # Or -2
print(flip(a, 3)) # Or -1
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]]])
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]]])
tensor([[[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]],
[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]]]])
tensor([[[[ 9., 10., 11., 12.],
[ 5., 6., 7., 8.],
[ 1., 2., 3., 4.]],
[[21., 22., 23., 24.],
[17., 18., 19., 20.],
[13., 14., 15., 16.]]]])
tensor([[[[ 4., 3., 2., 1.],
[ 8., 7., 6., 5.],
[12., 11., 10., 9.]],
[[16., 15., 14., 13.],
[20., 19., 18., 17.],
[24., 23., 22., 21.]]]])
以下是pytorch>=0.4.0的代码
# https://github.com/pytorch/pytorch/issues/229
import torch
def flip(x, dim):
indices = [slice(None)] * x.dim()
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
dtype=torch.long, device=x.device)
return x[tuple(indices)]