目录
1. torch.transpose(dim0, dim1)
2. torch.permute(dims)
3. 在转置函数之后使用view()
PyTorch中使用torch.transpose() 和 torch.permute()可对张量进行维度的转置,具体内容如下:
参数: 需要转置的两个维度,即dim0和dim1,顺序可以互换,但一次只能进行两个维度的转换
示例:将shape为[3, 4]的张量,通过transpose()转换成维度为[4, 3]的张量
x = torch.rand(3, 4)
print(x.shape) # torch.Size([3, 4])
y1 = x.transpose(0, 1)
print(y1.shape) # torch.Size([4, 3])
y2 = x.transpose(1, 0)
print(y2.shape) # torch.Size([4, 3])
参数: dims,可以是list, tuple等,表示进行转置的时候可以一次性进行多维度的转置,且必须传入所有维度数,参数有顺序关系,参数的顺序表示原tensor的哪一维。
示例1:进行二维转置
x.permute(1, 0)表示转置后x的shape为[原tensor第1维,原tensor第0维],也就是[4, 3]。更简单的理解就是把原来的维度比作一个数组 list=[3, 4],permute中的数表示数组对应的index,则转置后的x为 [list[1], list[0]] = [4, 3]
x = torch.rand(3, 4)
print(x.shape) # torch.Size([3, 4])
y1 = x.permute(1, 0)
print(y1.shape) # torch.Size([4, 3])
示例2:进行多维度同时转置。按照上面数组的方式进行理解,x的原维度表示数组 array = [3, 4, 5, 8],转置后的维度 = [array[1], array[0], array[3], array[2]] = [4, 3, 8, 5]
x = torch.rand(3, 4, 5, 8)
print(x.shape) # torch.Size([3, 4, 5, 8])
y1 = x.permute(1, 0, 3, 2)
print(y1.shape) # torch.Size([4, 3, 8, 5])
注意:在使用transpose或permute之后,若要使用view()改变其形状,必须先contiguous()。或者直接使用reshape()直接代替contiguous().view(),具体原因可见本博客PyTorch系列中的 Pytorch基础 - 6. torch.reshape() 和 torch.view()
示例1:直接使用view(),会报错
示例2:使用contiguous().view(),未报错,且得到了输出结果
x = torch.rand(3, 4)
print(x.shape) # torch.Size([3, 4])
y1 = x.permute(1, 0)
print(y1.shape) # torch.Size([4, 3])
z = y1.contiguous().view(-1)
print(z.shape) # torch.Size([12])