【Pytorch】torch.squeeze和torch.unsqueeze函数

【Pytorch】torch.squeeze和torch.unsqueeze

torch.squeeze(input, int)

返回一个张量,其可以在原来的input的shape基础之上,压缩维度,比如输入的input的shape为torch.Size([31, 1, 8]),调用torch.squeeze(input, 1),意思就是将原来的input的第二列**(0,1,2三列)**“删除”,返回的tensor的shape为 torch.Size([31, 8])

注意点 int的范围:假设原来shape有 x 列,则int可以取 [-x, x-1]

>>> a = torch.rand([32, 1, 8])
>>> a = torch.squeeze(a, 1)
>>> print(a.shape)
torch.Size([32, 8])

更详细的请参考官网:torch.squeeze — PyTorch 1.12 documentation

torch.unsqueeze(input, int)

返回一个张量,其可以在原来的input的shape基础之上,增加维度,比如输入的input的shape为torch.Size([31, 8]),调用torch.squeeze(input, 1),意思就是将原来的input的第二列**(0,1,2三列)**“替换”为1,返回的tensor的shape为 torch.Size([31, 1, 8])

注意点 int的范围:假设原来shape有x列,则int可以取 [-(x+1), x]

>>> a = torch.rand([32, 8])
>>> a = torch.unsqueeze(a, 1)
>>> print(a.shape)
torch.Size([32, 1, 8])

更详细的请参考官网:torch.unsqueeze — PyTorch 1.12 documentation

你可能感兴趣的:(Pytorch,python,pytorch)