numpy的squeeze、unsqueeze以及torch的expand函数

一、squeeze函数

  • 作用:从指定数组中删除长度为1的维度
  • 用法:
numpy.squeeze(a, axis=None) # 或者torch.squeeze(a, axis=None)
 - a为指定数组
 - axis不为None时,指定的维度必须是长度为1的单维度;为None时,删除所有长度为1的单维度
  • 举例
import numpy as np
a = np.arange(2).reshape(2, 1, 1)
# array([[[0]],
#       [[1]]])

np.squeeze(a).shape
# (2,)
np.squeeze(a, axis=1).shape
# (2, 1)
np.squeeze(a, axis=2).shape
# (2, 1)

二、unsqueeze函数

squeeze作用相反:在指定数组中加入长度为1的维度。用法类似。

三、expand函数

如下图:
numpy的squeeze、unsqueeze以及torch的expand函数_第1张图片

你可能感兴趣的:(深度学习,python基础基础知识)