torch.squeeze
torch.squeeze(input, dim = None, *, out = None)
. input:输入的张量
. dim 选择需要降维的维度,默认是None
为什么要降维
如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。在多维张量中,如果某一个维度是1,那么这个维度是为了扩充维度,所以为了加快计算,进行降维操作时可以去掉1的维度。
import torch
A = torch.ones((1,2,3,1,4,2))
A.shape
torch.Size([1, 2, 3, 1, 4, 2])
B = torch.squeeze(A,dim=0)
B.shape
torch.Size([2, 3, 1, 4, 2])
C = torch.squeeze(A,dim=3)
C.shape
torch.Size([1, 2, 3, 4, 2])
D = torch.squeeze(A,dim=2)
D.shape
torch.Size([1, 2, 3, 1, 4, 2])
E = torch.squeeze(A)
E.shape
torch.Size([2, 3, 4, 2])
不指定维度,则会将所有为1 的维度全部降维,保留不是1 的维度
torch.unsqueeze 是为了升维
torch.unsqueeze(input,dim)
input: 插入的张量
dim: 指定在某个维度进行升维
W = torch.ones((2,3,5))
W.shape
torch.Size([2, 3, 5])
M = torch.unsqueeze(W, dim = 0)
M.shape
torch.Size([1, 2, 3, 5])