⭐能让人成长的,从来不是停留在舒适区,再迈一步,再试一次,你总能发现一个更强大的自己。
一、squeeze()
二、unsqueeze()
先来看看官方的Docs,链接在这里:torch.squeeze — PyTorch master documentation
这个函数就是返回一个张量,将input中大小为1的维度都删除。例如:假设一个输入的shape为(AX1XBXCX1XD),则其output的shape为:(AXBXCXD)。若给定维度,则删除给定的维度,但是只有大小为1的维度才会被删除。下面举一个例子:
import torch
x = torch.zeros(1,2,1,2,3)
print(f"x:{x.shape}")
y = torch.squeeze(x) # x中大小为1的维度都删除
print(f"y:{y.shape}")
z = torch.squeeze(x,0) # 删除第一个维度
print(f"z:{z.shape}")
w = torch.squeeze(x,-3) # -1是指倒数第一个维度,-2是指倒数第二个,依次类推
print(f"w:{w.shape}")
g = torch.squeeze(x,1) # 只有大小为1的维度才可以被删除
print(f"g:{g.shape}")
结果如下:
x:torch.Size([1, 2, 1, 2, 3])
y:torch.Size([2, 2, 3])
z:torch.Size([2, 1, 2, 3])
w:torch.Size([1, 2, 2, 3])
g:torch.Size([1, 2, 1, 2, 3])
官方Docs如下,链接在这里:torch.unsqueeze — PyTorch master documentation
unsqueeze()就是给指定位置加上维数为一的维度。返回的张量与该张量共享相同的基础数据。例子如下:
import torch
x = torch.zeros(1,2,1,2,3)
print(f"x:{x.shape}")
y = torch.unsqueeze(x,1) # 在第二维增加一个维度,该维度的大小为1
print(f"y:{y.shape}")
结果如下:
x:torch.Size([1, 2, 1, 2, 3])
y:torch.Size([1, 1, 2, 1, 2, 3])