[Python] 增加维度或者减少维度:a.squeeze(axis) 和 a.unsqueeze(axis)

a.squeeze(axis)a.unsqueeze(axis) 用于对tensor升维和降维,括号内的参数表示代表 需要处理的维度

参数含义

  • 0:第一个维度
  • 1:第二个维度
  • ···
  • -1:倒数第一个维度
  • -2:倒数第二个维度

a.squeeze(axis):降维

需要降低的维度,必须为1,不为1时操作无效果。

a.squeeze(axis)默认是将 a 中所有为 1 的维度删掉。

x = np.array([[[0], [1], [2]]])
# x=
# [[[0]
#  [1]
#  [2]]]

print(x.shape)
# (1, 3, 1)

x1 = np.squeeze(x)    #  a 中所有为 1 的维度删掉
print(x1)
# [0 1 2]
print(x1.shape)  
# (3,)

a.unsqueeze(axis):升维

参数指定的位置增加一个维度。

print(x.shape)
# torch.Size([4, 3])

x = x.unsqueeze(0) # 在第一维增加
print(x.shape)
# torch.Size([1, 4, 3])

你可能感兴趣的:(Python)