Pytorch 之squeeze和unsqueeze用法

Pytorch使用中常会用到torch.squeeze()和torch.unsqueeze()函数:

查找资料相关记录如下:

参考博客:https://blog.csdn.net/qq_39709535/article/details/81841426

1. torch.squeeze(input, dim = None, out = None): 返回一个tensor,当dim不设值时,去掉输入的tensor的所有维度为1的维度; 当dim为某一整数(0<=dim 另外,当input是一维的时候,squeeze不变

>>> x = torch.zeros(1,1,2,1,3)
>>> x.dim()
5
>>> torch.squeeze(x).size()#去掉dim=1的维度
torch.Size([2, 3])
>>> torch.squeeze(x,0).size()  # dim=0表示第一维,且第一维的维度为1,所以去掉
torch.Size([1, 2, 1, 3])
>>> torch.squeeze(x,3).size()
torch.Size([1, 1, 2, 3])
>>> torch.squeeze(x,2).size()  # dim=2,第三维的维度为2!=1,所以不变
torch.Size([1, 1, 2, 1, 3])
 

2. torch.unqueeze(input, dim, out=None): 和squeeze作用相反,unsqueeze()在dim维插入一个维度为1的维,例如原来x是n×m维的,torch.unqueeze(x,0)这返回1×n×m的tensor

 

>>> x = torch.tensor([1,2,3])#dim=1,即(3)
>>> torch.unsqueeze(x,1)#变为(3,1)的矩阵
tensor([[ 1],
        [ 2],

 

你可能感兴趣的:(DL)