torch.unsqueeze() 和 torch.squeeze()详解

1. torch.squeeze

torch.squeeze(input, dim=None, *, out=None)

  • input:输入的张量
  • dim:选择需要降维的维度,默认是None
    squeeze的主要作用是降维。

示例展示

x = torch.zeros(2, 1, 2, 1, 2)
print(x.shape)
x = x.squeeze()
print(x)
print(x.shape)

在这里插入图片描述

x = torch.zeros(2, 1, 2, 1, 2)
print(x.shape)
x = x.squeeze(0)
print(x.shape)
x = x.squeeze(1)
print(x.shape)

在这里插入图片描述
多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。
在多维张量中,如果某一个维度是1,那么这个维度是为了扩充维度,所以为了加快计算,进行降维操作时可以去掉1的维度。

2. torch.unsqueeze

torch.squeeze是为了降维,那么torch.unsqueeze是了升维。
torch.unsqueeze(input, dim)

  • input:输入的张量
  • dim:插入维度的索引,默认是None

示例展示

x = torch.tensor([1, 2, 3, 4])
print(x)
print(x.size())
print('*'*50)
x = x.unsqueeze(1)
print(x)
print(x.size())

torch.unsqueeze() 和 torch.squeeze()详解_第1张图片

3. squeeze_和unsqueeze_

squeeze_和unsqueeze_分别在squeeze和unsqueeze的基础上增加下划线,区别在于是否改变原来张量。
加上“_”,将会直接改变原始张量,否则不直接改变原始张量。

示例展示

x = torch.zeros(2, 1, 2, 1, 2)
y = torch.zeros(2, 1, 2, 1, 2)
x_t = x.squeeze_(1)
y_t = y.squeeze(1)

print('squeeze原始张量:',y.size())
print('squeeze变化张量:',y_t.size())
print('squeeze_原始张量:',x.size())
print('squeeze_变化张量:',x_t.size())

在这里插入图片描述

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([1, 2, 3, 4])
x_t = x.unsqueeze(1)
y_t = y.unsqueeze_(1)

print('unsqueeze原始张量:',x.size())
print('unsqueeze变化张量:',x_t.size())
print('unsqueeze_原始张量:',y.size())
print('unsqueeze_变化张量:',y_t.size())

在这里插入图片描述

以上就是有关squeeze和unsqueeze的知识点,有啥问题欢迎提问!

你可能感兴趣的:(python知识,pytorch学习,深度学习,python,机器学习,pytorch)