目录
说明
定义数据
降维
增维
定位降维
(如需制作课程资源(幻灯片、实训手册、视频等)请私信给我)
squeeze()是降维函数,unsqueeze()是增维函数。具体使用方法如下:
import torch
m = torch.Tensor(1,2,3)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00], [3.0674e-41, 1.1210e-43, 0.0000e+00]]])
维度为
torch.Size([1, 2, 3])
squeeze()缺省函数值,是去除所有维度为1的维度。定义函数值,在后面的【定位增维】中介绍。
m = m.squeeze()
m
tensor([[1.0769e-02, 3.0674e-41, 6.9797e+00], [3.0674e-41, 1.1210e-43, 0.0000e+00]])
维度为
torch.Size([2, 3])
dim为unsqueeze()的函数值,这里值为0,表示在第1维插入一个维度
m = m.unsqueeze(0)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00], [3.0674e-41, 1.1210e-43, 0.0000e+00]]])
维度为
torch.Size([1, 2, 3])
dim为1,表示在第2维插入一个维度
m = m.unsqueeze(1)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00]], [[3.0674e-41, 1.1210e-43, 0.0000e+00]]])
torch.Size([2, 1, 3])
dim为2,表示在第3维插入一个维度
m = m.unsqueeze(2)
m
tensor([[[1.0769e-02], [3.0674e-41], [6.9797e+00]], [[3.0674e-41], [1.1210e-43], [0.0000e+00]]])
torch.Size([2, 3, 1])
dim为-1,表示在倒数第1维插入一个维度
m = m.unsqueeze(-1)
m
tensor([[[1.0769e-02], [3.0674e-41], [6.9797e+00]], [[3.0674e-41], [1.1210e-43], [0.0000e+00]]])
torch.Size([2, 3, 1])
dim为-2,表示在倒数第2维插入一个维度
m = m.unsqueeze(-2)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00]], [[3.0674e-41, 1.1210e-43, 0.0000e+00]]])
torch.Size([2, 1, 3])
dim为-3,表示在倒数第3维插入一个维度
m = m.unsqueeze(-3)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00], [3.0674e-41, 1.1210e-43, 0.0000e+00]]])
torch.Size([1, 2, 3])
squeeze() 缺省值为去掉所有维度为1的维度。
m = torch.Tensor(1,2,1,3,1)
m
tensor([[[[[6.9997e+00], [3.0674e-41], [7.0005e+00]]], [[[3.0674e-41], [8.9683e-44], [0.0000e+00]]]]])
torch.Size([1, 2, 1, 3, 1])
m = m.squeeze()
m
tensor([[6.9997e+00, 3.0674e-41, 7.0005e+00], [3.0674e-41, 8.9683e-44, 0.0000e+00]])
torch.Size([2, 3])
squeeze(0)若第一维度为1则去除第一维度(第一维度不为1则不去除)
m = torch.Tensor(1,2,1,3,1)
m
tensor([[[[[7.8503e+00], [3.0674e-41], [6.9897e+00]]], [[[3.0674e-41], [8.9683e-44], [0.0000e+00]]]]])
torch.Size([1, 2, 1, 3, 1])
m = m.squeeze(0)
m
tensor([[[[7.8503e+00], [3.0674e-41], [6.9897e+00]]], [[[3.0674e-41], [8.9683e-44], [0.0000e+00]]]])
torch.Size([2, 1, 3, 1])
squeeze(1)若第二维度为1则去除第二维度(第二维度不为1则不去除)
m = m.squeeze(1)
m
tensor([[[7.8503e+00], [3.0674e-41], [6.9897e+00]], [[3.0674e-41], [8.9683e-44], [0.0000e+00]]])
torch.Size([2, 3, 1])
squeeze(2)若第三维度为1则去除第二维度(第三维度不为1则不去除)
m = m.squeeze(2)
m
tensor([[7.8503e+00, 3.0674e-41, 6.9897e+00], [3.0674e-41, 8.9683e-44, 0.0000e+00]])
torch.Size([2, 3])