torch.transpose(input, dim0, dim1) → Tensor
功能:将输入数组的dim0
维度和dim1
维度交换
输入:
input
:需要做维度交换的数组dim0
、dim1
:交换的维度注意:
torch.transpose
也可以通过a.transpose
实现,后者默认转换数组a
view
改变视图,则会报错二维数组转置
import torch
a=torch.arange(20).reshape(4,5)
b=torch.transpose(a,dim0=1,dim1=0)
print(a)
print(b)
输出
# 原数组
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
# 转置后
tensor([[ 0, 5, 10, 15],
[ 1, 6, 11, 16],
[ 2, 7, 12, 17],
[ 3, 8, 13, 18],
[ 4, 9, 14, 19]])
高维数组转置
import torch
a=torch.arange(24).reshape(2,3,4)
b=torch.transpose(a,dim0=0,dim1=1)
c=torch.transpose(a,dim0=0,dim1=2)
d=torch.transpose(a,dim0=1,dim1=2)
print(a)
print(a.shape)
print(b)
print(b.shape)
print(c)
print(c.shape)
print(d)
print(d.shape)
输出
# 原数组
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
torch.Size([2, 3, 4])
# 1、2维度交换
tensor([[[ 0, 1, 2, 3],
[12, 13, 14, 15]],
[[ 4, 5, 6, 7],
[16, 17, 18, 19]],
[[ 8, 9, 10, 11],
[20, 21, 22, 23]]])
torch.Size([3, 2, 4])
# 1、3维度交换
tensor([[[ 0, 12],
[ 4, 16],
[ 8, 20]],
[[ 1, 13],
[ 5, 17],
[ 9, 21]],
[[ 2, 14],
[ 6, 18],
[10, 22]],
[[ 3, 15],
[ 7, 19],
[11, 23]]])
torch.Size([4, 3, 2])
# 2、3维度交换
tensor([[[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]],
[[12, 16, 20],
[13, 17, 21],
[14, 18, 22],
[15, 19, 23]]])
torch.Size([2, 4, 3])
存储地址不连续
import torch
a=torch.arange(24).reshape(2,3,4)
b=torch.transpose(a,dim0=0,dim1=1)
c=b.view(2,3,4)
会报错,RuntimeError
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
换成reshape
方法来改变视图或者加一步.contiguous()
操作,让存储内存连续
import torch
a=torch.arange(24).reshape(2,3,4)
b=torch.transpose(a,dim0=0,dim1=1)
c=b.reshape(4,6)
# 让b的存储变连续
b=b.contiguous()
d=b.view(4,6)
print(c)
print(d)
输出
# 两种方式结果一样
tensor([[ 0, 1, 2, 3, 12, 13],
[14, 15, 4, 5, 6, 7],
[16, 17, 18, 19, 8, 9],
[10, 11, 20, 21, 22, 23]])
tensor([[ 0, 1, 2, 3, 12, 13],
[14, 15, 4, 5, 6, 7],
[16, 17, 18, 19, 8, 9],
[10, 11, 20, 21, 22, 23]])
tensor.permute(*dims) → Tensor
功能:将数组tensor
的维度按输入dims
的顺序进行交换
输入:
dims
:维度交换顺序注意:
permute
方法只能通过tensor.permute
实现,不能通过torch.permute
实现transpose
类似经过交换后的内存地址不连续,如果用view
改变视图,则会报错tensor.permute
的功能与np.transpose
类似,均可以同时对一个数组进行多维度交换操作import torch
a=torch.arange(24).reshape(2,3,4)
# 0,1,2->1,2,0
b=a.permute(1,2,0)
print(a)
print(a.shape)
print(b)
print(b.shape)
输出
# 原数组
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
torch.Size([2, 3, 4])
# 交换维度后数组
tensor([[[ 0, 12],
[ 1, 13],
[ 2, 14],
[ 3, 15]],
[[ 4, 16],
[ 5, 17],
[ 6, 18],
[ 7, 19]],
[[ 8, 20],
[ 9, 21],
[10, 22],
[11, 23]]])
torch.Size([3, 4, 2])
permute
一次可以操作多个维度,并且必须传入所有维度数;而transpose
只能同时交换两个维度,并且只能传入两个数permute
可以通过多个transpose
实现transpose
传入的dim
无顺序之分,传入(1,0)和(0,1)结果一样,都是第一维度和第二维度进行交换;permute
传入的dim
有顺序之分,传入(0,1)代表交换后原第一维度在前面,原第二维度在后面;传入(1,0)代表交换后原第二维度在前面,原第一维度在后面在计算机视觉中,由于cv2格式(numpy)读取的图片为H×W×C,通道数在最后;而torch中的图片常以C×H×W存在,因此,当图片在tensor与numpy之间转换时,需要用到交换数组维度的函数来将图片存储格式转化
torch.transpose():https://pytorch.org/docs/stable/generated/torch.transpose.html?highlight=transpose#torch.transpose
toch.permute():https://pytorch.org/docs/1.9.1/generated/torch.Tensor.permute.html?highlight=torch%20permute#
点个赞再走吧