torch.squeeze(input, dim=None, out=None):对数据的维度进行压缩,去掉维数为1的的维度。
squeeze函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。
squeeze(0):代表若第一维度值为1则去除第一维度,例如 a.squeeze(0),a 为 torch.tensor() 格式张量。
squeeze(1):代表若第二维度值为1则去除第二维度
squeeze(-1):去除最后维度值为1的维度
torch.unsqueeze (input, dim=None, out=None):对数据的维度进行扩容,即升维。
使用格式可以是torch.unsqueeze(x, 0)
,也可为是x.unsqueeze(0)
。
a = torch.Tensor(1, 3)
print(a)
print(a.squeeze(0))
print(a.squeeze(1))
b = torch.Tensor(2, 3)
print(b)
print(b.squeeze(0))
print(b.squeeze(1))
c = torch.Tensor(3, 1)
print(c)
print(c.squeeze(0))
print(c.squeeze(1))
x = torch.tensor([1, 2, 3, 4])
print(x)
print(torch.unsqueeze(x, 0))
print(torch.unsqueeze(x, 1))
定义张量 a,为 2 维,第一维度有 1 个元素,第二维度有 3 个元素。
输出:tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
通过 a.squeeze(0) 对第一维度进行降维,此时第一维度有 1 个元素,可降维,第一维度消失,第二维度自动变成第一维度有三个元素,与 a 相比,即消失了一层 “[]”。
输出:tensor([2.6994e-30, 2.4164e-13, 1.8392e-13])
通过 a.squeeze(1) 对第二维度进行降维,此时第一维度有 3 个元素,不可降维,则不做操作,输出与 a 相同。
输出:tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
定义张量 b,为 2 维,第一维度有 2 个元素,第二维度有 3 个元素。
第一、二维度均不可降维,因为三次输出相同。
输出:
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
定义张量 c,为 2 维,第一维度有 3 个元素,第二维度有 1 个元素。
输出:
tensor([[0.0000e+00],
[ nan],
[5.2781e-24]])
通过 c.squeeze(0) 对第一维度进行降维,此时第一维度有 3 个元素,不可降维,则不做操作,输出与 c 相同。
输出:
tensor([[0.0000e+00],
[ nan],
[5.2781e-24]])
通过 c.squeeze(1) 对第二维度进行降维,此时第二维度有 1 个元素,可降维,第二维度消失,第二维度数值自动进入第一维度中。
输出:
tensor([0.0000e+00, nan, 5.2781e-24])
定义张量 x,为 1 维,其中数值依次为 1, 2, 3, 4。
输出:tensor([1, 2, 3, 4])
通过 x.unsqueeze(0) 于第一维度位置增加一个维度,使原张量变成 2 维,维度变为 (1, 4)。与 x 相比,即增加了一层 “[]”。
输出:tensor([[1, 2, 3, 4]])
通过 x.unsqueeze(1) 于第二维度位置增加一个维度,使原张量变成 2 维,维度变为 (4, 1)。
输出:tensor([[1], [2], [3], [4]])
tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
tensor([2.6994e-30, 2.4164e-13, 1.8392e-13])
tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[0., 0., 0.],
[0., 0., 0.]])
tensor([[0.0000e+00],
[ nan],
[5.2781e-24]])
tensor([[0.0000e+00],
[ nan],
[5.2781e-24]])
tensor([0.0000e+00, nan, 5.2781e-24])
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
tensor([[1],
[2],
[3],
[4]])