torch.squeeze、 numpy.squeeze()详解

1. 降低维度 : squeeze、 numpy.squeeze()

torch版:

x.squeeze()

如果不传位置,则删除所有大小为1的维度。
例如,如果输入的形状为:(A×1×B×C×1×D)则输出张量的形状为(A×B×C×D)
传入位置,若矩阵的相应位置的维度为1,则删除,其他维度不变;若不为1,则不变。
总结:只能维度为1时,才能降低相应位置维度

torch示例:

import torch
import numpy as np

x = torch.rand((2, 2, 1, 3, 1, 3))
# 不传入位置,删除矩阵中所有维度为1的
b = x.squeeze()
# 如果1处维度为1,则删除此维度,其他维度不变
c = x.squeeze(1)
# 如果2处维度为1,则删除此维度,其他维度不变
d = x.squeeze(2)
print('x_shape:', x.shape)  # torch.Size([2, 2, 1, 3, 1, 3])
print('b_shape:', b.shape)  # b_shape: torch.Size([2, 2, 3, 3])
print('c_shape:', c.shape)  # c_shape: torch.Size([2, 2, 1, 3, 1, 3])
print('d_shape:', d.shape)  # d_shape: torch.Size([2, 2, 3, 1, 3])

numpy版:

从ndarray的shape中,去掉维度为1的。

默认去掉所有的1。

注意:只能去掉shape中的1,其他数字的维度无法去除。

当传入的轴对应维度不为1时,np会报错,报错如下:

ValueError: cannot select an axis to squeeze out which has size not equal to one

但是torch中不会报错,会返回原数据,见上面示例。

numpy示例:


import numpy as np

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = x.reshape(1, 1, 1, 2, 5)

a = x.squeeze()
b = x.squeeze(0)
c = x.squeeze(1)
d = x.squeeze(2)
# 维度不为1时,进行删除会报错
# e = x.squeeze(4)
print('x_shape:', x.shape)  # (1, 1, 1, 2, 5)
print('a_shape:', a.shape)  # (2, 5)
print('b_shape:', b.shape)  # (1, 1, 2, 5)
print('c_shape:', c.shape)  # (1, 1, 2, 5)
print('d_shape:', d.shape)  # (1, 1, 2, 5)
# print('e_shape:', e.shape)  

你可能感兴趣的:(人工智能,numpy,python,深度学习)