作用:移除指定或所有维数为1的维度,从而得到维度减少的张量
解释一下:
x=torch.zeros(5,1,1,1)
print(x)
'输出'
tensor([[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]]])
举个极端点的例子,这是一个4维的数组,除了第0个维度之外每个维度的维数均为1
也就是说,每一个0都被3个括号括着,这显然不太合理
下面调用squeeze函数:
y = x.squeeze()
print(y)
print(y.shape)
'输出'
tensor([0., 0., 0., 0., 0.])
torch.Size([5])
瞬间一系列的括号都没有了,是不是看着舒服了许多?
进一步:
y = x.squeeze(1)
print(y)
print(y.shape)
'输出'
tensor([[[0.]],
[[0.]],
[[0.]],
[[0.]],
[[0.]]])
torch.Size([5, 1, 1])
这里添加了参数1,这样就只压缩了第1个维度(计数从0开始),一个0被2个括号括着
但压缩的前提是,该张量必须有维数为1的维度,比如:
y = x.squeeze(0)
print(y)
print(y.shape)
a = torch.tensor([[1, 1, 1], [2, 2, 2]])
b = a.squeeze()
print(b)
print(b.shape)
'输出'
tensor([[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]],
[[[0.]]]])
torch.Size([5, 1, 1, 1])
tensor([[1, 1, 1],
[2, 2, 2]])
torch.Size([2, 3])
y和b相对于x和a均没有发生变化,原因就是:x的第0个维度,维数不是1;a中更是没有维数为1的维度
另外:x.squeeze() 或者 torch.squeeze(x) 都不会让x发生改变
y = x.squeeze()
print(x.shape)
print(y.shape)
'输出'
torch.Size([5, 1, 1, 1])
torch.Size([5])
作用:在张量的制定维度插入新的维度得到维度提升的张量
举个例子:
x= torch.zeros(5)
print(x)
print(x.shape)
'输出'
tensor([0, 0, 0, 0, 0])
torch.Size([5])
一维张量,总共5个0,接下来依次操作:
y = x.unsqueeze(dim=0)
print(y)
print(y.shape)
y = x.unsqueeze(dim=1)
print(y)
print(y.shape)
z = y.unsqueeze(dim=2)
print(z)
print(z.shape)
'输出'
tensor([[0., 0., 0., 0., 0.]])
torch.Size([1, 5])
tensor([[0.],
[0.],
[0.],
[0.],
[0.]])
torch.Size([5, 1])
tensor([[[0.]],
[[0.]],
[[0.]],
[[0.]],
[[0.]]])
torch.Size([5, 1, 1])
把第0维进行扩张,就是在最外面加了一个括号
把第1维进行扩张,就是把里面的每个元素元素(也可以理解成是,扩充后的第1维,也就是0.)都加一个括号
继续套,选择dim=2,还是把最内层的了(也可以理解成是,扩充后的第2维,也就是0.),都加一个括号
再举个例子:
a = torch.tensor([[1, 1, 1], [2, 2, 2]])
print(a)
print(a.shape)
b = a.unsqueeze(dim=0)
print(b)
print(b.shape)
b = a.unsqueeze(dim=1)
print(b)
print(b.shape)
b = a.unsqueeze(dim=2)
print(b)
print(b.shape)
'输出'
tensor([[1, 1, 1],
[2, 2, 2]])
torch.Size([2, 3])
tensor([[[1, 1, 1],
[2, 2, 2]]])
torch.Size([1, 2, 3])
tensor([[[1, 1, 1]],
[[2, 2, 2]]])
torch.Size([2, 1, 3])
tensor([[[1],
[1],
[1]],
[[2],
[2],
[2]]])
torch.Size([2, 3, 1])
怎么套的括号,是不是一目了然~
同样:x.unsqueeze() 或者 torch.unsqueeze(x) 都不会让x发生改变
x = torch.tensor([[1, 1, 1], [2, 2, 2]])
print(x)
y = x.unsqueeze(dim=0)
print(y.shape)
print(x.shape)
'输出'
torch.Size([1, 2, 3])
torch.Size([2, 3])