Contents
1 The view function
1.1 Specifying the target shape
1.2 Inferring the target shape automatically
1.3 Flattening a tensor to one dimension
2 The reshape function
2.1 Specifying the target shape
2.2 Inferring the target shape automatically
2.3 Flattening a tensor to one dimension
2.4 Reshaping with tensor.reshape
3 The squeeze function
3.1 torch.squeeze removes all size-1 dimensions
3.2 torch.squeeze with dim specified
3.3 tensor.squeeze removes size-1 dimensions
4 The unsqueeze function
4.1 torch.unsqueeze inserts a new dimension at dim
4.2 tensor.unsqueeze inserts a new dimension at dim
5 The transpose function
5.1 torch.transpose swaps two given dimensions
5.2 tensor.transpose swaps two given dimensions
6 The expand function
7 The repeat function
8 The permute function
Changing tensor dimensions is a common and important operation when building models in PyTorch. Starting from practical use cases, this article walks through the most commonly used dimension-changing methods, including view, reshape, squeeze, unsqueeze, transpose, and more.
The view function in PyTorch is used mainly to restructure a tensor's dimensions: it returns a tensor with the same data but a different shape. view operates on a Tensor and also returns a Tensor:
def view(self, *size: _int) -> Tensor: ...
A more intuitive way to read the signature: view(a, b, …), where the arguments together give the shape of the restructured tensor.
Specifying the target shape by hand, the following turns a one-dimensional tensor into a 3×8 tensor:
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
a2 = a1.view(3, 8)
print(a1)
print(a2)
print(a1.shape)
print(a2.shape)
Running the program prints:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24])
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16],
[17, 18, 19, 20, 21, 22, 23, 24]])
torch.Size([24])
torch.Size([3, 8])
If one argument is -1, that dimension is determined by the remaining dimensions and filled in by PyTorch automatically:
import torch
a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
a4 = a3.view(4, -1)
a5 = a3.view(2, 3, -1)
a6 = a3.view(-1, 3, 2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a3.shape)
print(a4.shape)
print(a5.shape)
print(a6.shape)
Running the program prints:
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24]])
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
tensor([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[ 7, 8],
[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16],
[17, 18]],
[[19, 20],
[21, 22],
[23, 24]]])
torch.Size([24])
torch.Size([4, 6])
torch.Size([2, 3, 4])
torch.Size([4, 3, 2])
import torch
a7 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
a8 = a7.view(-1)
print(a7)
print(a8)
print(a7.shape)
print(a8.shape)
Running the program prints:
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24])
torch.Size([2, 12])
torch.Size([24])
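One caveat the examples above do not show: view never copies data, so it requires the tensor's memory layout to be contiguous. A minimal sketch (not from the original article) of what happens when it is not:

```python
import torch

a = torch.arange(12).reshape(3, 4)
b = a.t()  # transpose returns a non-contiguous view

print(b.is_contiguous())  # False

# view requires contiguous memory, so flattening b this way fails
try:
    b.view(-1)
except RuntimeError as e:
    print("view failed:", e)

# .contiguous() first copies the data into a contiguous layout
c = b.contiguous().view(-1)
print(c.shape)  # torch.Size([12])
```

In such cases reshape (covered next) can be used directly, since it falls back to copying when a view is impossible.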
torch.reshape returns a tensor with the same data and number of elements as input, but with the given shape. When possible the result is a view of input; otherwise it is a copy.
torch.reshape(input, shape) → [Tensor]
reshape can also be used directly on a tensor, in the form:
tensor.reshape(shape) → [Tensor]
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = torch.reshape(a1, (3, 4))
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
import torch
a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a4 = torch.reshape(a3, (-1, 6))
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([2, 6])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
import torch
a5 = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
a6 = torch.reshape(a5, (-1,))
print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
Running the program prints:
torch.Size([2, 6])
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
import torch
a7 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a8 = a7.reshape(6, 2)
a9 = a7.reshape(-1, 3)
a10 = a9.reshape(-1)
print(a7.shape)
print(a7)
print(a8.shape)
print(a8)
print(a9.shape)
print(a9)
print(a10.shape)
print(a10)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([6, 2])
tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10],
[11, 12]])
torch.Size([4, 3])
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
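The "view if possible, copy otherwise" behavior mentioned above can be observed directly by comparing data pointers. A small sketch (not from the original article):

```python
import torch

a = torch.arange(12).reshape(3, 4)

# contiguous input: reshape returns a view sharing the same storage
v = a.reshape(4, 3)
print(v.data_ptr() == a.data_ptr())  # True

# non-contiguous input (here, after a transpose): reshape silently copies
t = a.t()
r = t.reshape(-1)
print(r.data_ptr() == a.data_ptr())  # False
```

This is the practical difference from view: reshape always succeeds, but you cannot rely on the result aliasing the original tensor.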
torch.squeeze removes from input every dimension whose size is 1. If dim is specified, only that dimension is considered: it is kept if its size is not 1, and removed if it is.
torch.squeeze(input, dim=None) → [Tensor]
squeeze can also be used directly on a tensor, in the form:
tensor.squeeze(dim=None) → [Tensor]
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 1, 4)
a3 = torch.squeeze(a2)
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
Running the program prints (the second dimension of a2 is removed):
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
import torch
a4 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a5 = a4.reshape(3, 1, 4)
a6 = torch.squeeze(a5, 0)
a7 = torch.squeeze(a5, 1)
print(a4.shape)
print(a4)
print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
print(a7.shape)
print(a7)
Running the program prints (the first dimension of a5 is not 1, so it is kept; the second dimension of a5 is 1, so it is removed):
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
import torch
a8 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a9 = a8.reshape(3, 1, 4)
a10 = a9.squeeze()
a11 = a9.squeeze(0)
a12 = a9.squeeze(1)
print(a8.shape)
print(a8)
print(a9.shape)
print(a9)
print(a10.shape)
print(a10)
print(a11.shape)
print(a11)
print(a12.shape)
print(a12)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
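A practical caveat worth adding here: calling squeeze() with no dim removes every size-1 dimension, which can silently drop a batch dimension when the batch happens to have size 1. A short sketch (not from the original article):

```python
import torch

# a batch containing one 1x5 sample: shape (1, 1, 5)
x = torch.zeros(1, 1, 5)

# no dim: *every* size-1 dimension goes, including the batch dimension
print(x.squeeze().shape)   # torch.Size([5])

# with dim, only that one dimension is removed
print(x.squeeze(1).shape)  # torch.Size([1, 5])

# dim may also be negative, counting from the end
print(torch.zeros(3, 4, 1).squeeze(-1).shape)  # torch.Size([3, 4])
```

Passing an explicit dim is therefore the safer habit in model code.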
torch.unsqueeze inserts a new dimension of size 1 at the given dim position. dim must lie in the range [-input.dim()-1, input.dim()+1), inclusive at the start and exclusive at the end.
torch.unsqueeze(input, dim) → [Tensor]
unsqueeze can also be used directly on a tensor, in the form:
tensor.unsqueeze(dim) → [Tensor]
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 4)
a3 = torch.unsqueeze(a2, 0)
a4 = torch.unsqueeze(a2, 2)
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
torch.Size([1, 3, 4])
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]]])
torch.Size([3, 4, 1])
tensor([[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 9],
[10],
[11],
[12]]])
import torch
a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a6 = a5.reshape(3, 4)
a7 = a6.unsqueeze(0)
a8 = a6.unsqueeze(1)
print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
print(a7.shape)
print(a7)
print(a8.shape)
print(a8)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 4])
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
torch.Size([1, 3, 4])
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]]])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
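As noted above, dim can be negative; -1 appends the new dimension at the end. A small sketch (not from the original article) of the valid range for a 2-D tensor:

```python
import torch

a = torch.arange(12).reshape(3, 4)

# -1 inserts the new dimension last, equivalent to dim = a.dim() (here 2)
print(a.unsqueeze(-1).shape)  # torch.Size([3, 4, 1])
print(a.unsqueeze(2).shape)   # torch.Size([3, 4, 1])

# -3 is the other end of the valid range [-3, 3) for a 2-D tensor
print(a.unsqueeze(-3).shape)  # torch.Size([1, 3, 4])
```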
torch.transpose returns a transposed version of the input tensor in which dimensions dim0 and dim1 are swapped.
torch.transpose(input, dim0, dim1) → [Tensor]
transpose can also be used directly on a tensor, in the form:
tensor.transpose(dim0, dim1) → [Tensor]
Parameters:
- input (Tensor) – the tensor to transpose
- dim0 (int) – the first dimension to be swapped
- dim1 (int) – the second dimension to be swapped
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(4, 3, 1)
a3 = torch.transpose(a2, 0, 1)
a4 = torch.transpose(a2, 1, 2)
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([4, 3, 1])
tensor([[[ 1],
[ 2],
[ 3]],
[[ 4],
[ 5],
[ 6]],
[[ 7],
[ 8],
[ 9]],
[[10],
[11],
[12]]])
torch.Size([3, 4, 1])
tensor([[[ 1],
[ 4],
[ 7],
[10]],
[[ 2],
[ 5],
[ 8],
[11]],
[[ 3],
[ 6],
[ 9],
[12]]])
torch.Size([4, 1, 3])
tensor([[[ 1, 2, 3]],
[[ 4, 5, 6]],
[[ 7, 8, 9]],
[[10, 11, 12]]])
import torch
a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a6 = a5.reshape(4, 3, 1)
a7 = a6.transpose(0, 1)
a8 = a6.transpose(1, 2)
print(a5.shape)
print(a5)
print(a6.shape)
print(a6)
print(a7.shape)
print(a7)
print(a8.shape)
print(a8)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([4, 3, 1])
tensor([[[ 1],
[ 2],
[ 3]],
[[ 4],
[ 5],
[ 6]],
[[ 7],
[ 8],
[ 9]],
[[10],
[11],
[12]]])
torch.Size([3, 4, 1])
tensor([[[ 1],
[ 4],
[ 7],
[10]],
[[ 2],
[ 5],
[ 8],
[11]],
[[ 3],
[ 6],
[ 9],
[12]]])
torch.Size([4, 1, 3])
tensor([[[ 1, 2, 3]],
[[ 4, 5, 6]],
[[ 7, 8, 9]],
[[10, 11, 12]]])
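Two properties of transpose worth keeping in mind, shown in a small sketch (not from the original article): the result is a view sharing storage with the original, and that view is generally no longer contiguous:

```python
import torch

a = torch.arange(6).reshape(2, 3)
b = a.transpose(0, 1)

# transpose returns a view: it shares storage with the original,
# so a write through one is visible through the other
b[0, 0] = 100
print(a[0, 0].item())     # 100

# the transposed view is generally no longer contiguous,
# so b.view(...) would fail without calling .contiguous() first
print(b.is_contiguous())  # False
```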
expand returns a new view of the tensor in which some dimensions are expanded to a larger size; a size of -1 means that dimension keeps its current size.
A tensor can also be expanded to more dimensions; the new dimensions are prepended at the front, and their size cannot be -1.
Expanding a tensor allocates no new memory; it only creates a new view on the existing tensor, in which any dimension of size 1 can be expanded to an arbitrary size without allocating memory.
Tensor.expand(*sizes) → [Tensor]
Parameters:
- sizes (torch.Size or int...) – the desired expanded size
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 1, 4, 1)
# a dimension of size 1 can be expanded to any size
a3 = a2.expand(3, 5, 4, 2)
# -1 keeps the corresponding dimension's size; expanding the first dimension from 3 to 6 would raise an error, since a dimension whose size is not 1 cannot be expanded
a4 = a2.expand(-1, 5, -1, -1)
# new dimensions can be added, but only at the front (anywhere else raises an error), and their size cannot be -1
a5 = a2.expand(2, -1, 5, -1, -1)
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
print(a5.shape)
print(a5)
Running the program prints (a dimension whose size is not 1 cannot be expanded):
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 1, 4, 1])
tensor([[[[ 1],
[ 2],
[ 3],
[ 4]]],
[[[ 5],
[ 6],
[ 7],
[ 8]]],
[[[ 9],
[10],
[11],
[12]]]])
torch.Size([3, 5, 4, 2])
tensor([[[[ 1, 1],
[ 2, 2],
[ 3, 3],
[ 4, 4]],
[[ 1, 1],
[ 2, 2],
[ 3, 3],
[ 4, 4]],
[[ 1, 1],
[ 2, 2],
[ 3, 3],
[ 4, 4]],
[[ 1, 1],
[ 2, 2],
[ 3, 3],
[ 4, 4]],
[[ 1, 1],
[ 2, 2],
[ 3, 3],
[ 4, 4]]],
[[[ 5, 5],
[ 6, 6],
[ 7, 7],
[ 8, 8]],
[[ 5, 5],
[ 6, 6],
[ 7, 7],
[ 8, 8]],
[[ 5, 5],
[ 6, 6],
[ 7, 7],
[ 8, 8]],
[[ 5, 5],
[ 6, 6],
[ 7, 7],
[ 8, 8]],
[[ 5, 5],
[ 6, 6],
[ 7, 7],
[ 8, 8]]],
[[[ 9, 9],
[10, 10],
[11, 11],
[12, 12]],
[[ 9, 9],
[10, 10],
[11, 11],
[12, 12]],
[[ 9, 9],
[10, 10],
[11, 11],
[12, 12]],
[[ 9, 9],
[10, 10],
[11, 11],
[12, 12]],
[[ 9, 9],
[10, 10],
[11, 11],
[12, 12]]]])
torch.Size([3, 5, 4, 1])
tensor([[[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]]],
[[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]]],
[[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]]]])
torch.Size([2, 3, 5, 4, 1])
tensor([[[[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]]],
[[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]]],
[[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]]]],
[[[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]],
[[ 1],
[ 2],
[ 3],
[ 4]]],
[[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]],
[[ 5],
[ 6],
[ 7],
[ 8]]],
[[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]],
[[ 9],
[10],
[11],
[12]]]]])
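Because expand allocates no memory, the expanded tensor aliases the original: every repeated element maps back to the same storage location. A small sketch (not from the original article) of the consequences:

```python
import torch

a = torch.tensor([[1], [2], [3]])  # shape (3, 1)
b = a.expand(3, 4)                 # no data is copied

# every column of b aliases the single column of a
a[0, 0] = 99
print(b[0])  # tensor([99, 99, 99, 99])

# since several elements of b share one memory location,
# most in-place writes on b itself raise a RuntimeError
try:
    b.add_(1)
except RuntimeError as e:
    print("in-place failed:", e)
```

When independent data is needed, use repeat (next section) or clone() instead.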
repeat copies the tensor along the specified dimensions. Unlike expand, this method copies the original tensor's data.
Tensor.repeat(*sizes) → [Tensor]
Parameters:
- sizes (torch.Size or int...) – the number of times to repeat the tensor along each dimension
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
print(a1.storage().data_ptr())
a2 = a1.reshape(3, 1, 4)
print(a2.storage().data_ptr())
a3 = a2.expand(3, 3, -1)
# after expand, the tensor's storage address is unchanged
print(a3.storage().data_ptr())
a4 = a2.repeat(2, 4, 1)
# after repeat, the tensor's storage address changes
print(a4.storage().data_ptr())
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
Running the program prints:
1974461518528
1974461518528
1974461518528
1974462302208
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([3, 3, 4])
tensor([[[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12]]])
torch.Size([6, 4, 4])
tensor([[[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12]],
[[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12]]])
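Since repeat copies the data, the result is fully independent of the original tensor, unlike an expand view. A minimal sketch (not from the original article):

```python
import torch

a = torch.tensor([1, 2, 3])
b = a.repeat(2)  # shape (6,): the data is copied

# modifying the original afterwards does not affect the repeated copy
a[0] = 99
print(a)  # tensor([99,  2,  3])
print(b)  # tensor([1, 2, 3, 1, 2, 3])
```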
permute returns a view of the original tensor with its dimensions rearranged.
torch.permute(input, dims) → [Tensor]
permute can also be used directly on a tensor, in the form:
tensor.permute(dims) → [Tensor]
Parameters:
- input (Tensor) – the tensor whose dimensions are to be rearranged
- dims (tuple of int) – the desired ordering of the dimensions
import torch
a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
a2 = a1.reshape(3, 1, 4)
a3 = torch.permute(a2, (2, 0, 1))
a4 = torch.permute(a2, (1, 0, 2))
a5 = a2.permute(1, 2, 0)
print(a1.shape)
print(a1)
print(a2.shape)
print(a2)
print(a3.shape)
print(a3)
print(a4.shape)
print(a4)
print(a5.shape)
print(a5)
Running the program prints:
torch.Size([12])
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([3, 1, 4])
tensor([[[ 1, 2, 3, 4]],
[[ 5, 6, 7, 8]],
[[ 9, 10, 11, 12]]])
torch.Size([4, 3, 1])
tensor([[[ 1],
[ 5],
[ 9]],
[[ 2],
[ 6],
[10]],
[[ 3],
[ 7],
[11]],
[[ 4],
[ 8],
[12]]])
torch.Size([1, 3, 4])
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]]])
torch.Size([1, 4, 3])
tensor([[[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11],
[ 4, 8, 12]]])
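The relationship to transpose is worth spelling out: transpose swaps exactly two dimensions, while permute reorders all of them at once, so a single transpose is a special case of permute. A small sketch (not from the original article):

```python
import torch

a = torch.arange(24).reshape(2, 3, 4)

# swapping dims 0 and 2 with transpose equals permute(2, 1, 0)
t = a.transpose(0, 2)
p = a.permute(2, 1, 0)
print(torch.equal(t, p))  # True

# permute does in one call what would otherwise take several transposes
q = a.permute(1, 2, 0)
print(q.shape)            # torch.Size([3, 4, 2])
```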