pytorch中squeeze()和unsqueeze()函数意义

1 unsqueeze()

函数功能:与squeeze()函数功能相反,用于添加维度;

>>> a = torch.Tensor(3)
>>> a
tensor([1.7718e+28, 1.0509e-38, 0.0000e+00])
>>> a.unsqueeze(0) # 扩展第一个维度
tensor([[1.7718e+28, 1.0509e-38, 0.0000e+00]])
>>> a.unsqueeze(1) #扩展第二个维度
tensor([[1.7718e+28],
        [1.0509e-38],
        [0.0000e+00]])

2 squeeze()

        squeeze本身有挤压的意思;

函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用;

        其中squeeze(0)代表若第一维度值为1则去除第一维度;

        squeeze(1)代表若第二维度值为1则去除第二维度;

>>> a = torch.Tensor(3,2)
>>> a
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])
>>> a.unsqueeze(0)
tensor([[[-6.5850e+34,  4.5759e-41],
         [-6.5850e+34,  4.5759e-41],
         [ 0.0000e+00,  0.0000e+00]]])
>>> a.squeeze(0) # 第一维度会被缩减
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])
>>> a.squeeze(1)# 第二维度不会被缩减
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])
>>> a.unsqueeze(-1) # 表示最后一个维度
tensor([[[-6.5850e+34],
         [ 4.5759e-41]],

        [[-6.5850e+34],
         [ 4.5759e-41]],

        [[ 0.0000e+00],
         [ 0.0000e+00]]])
>>> a.squeeze(-1)
tensor([[-6.5850e+34,  4.5759e-41],
        [-6.5850e+34,  4.5759e-41],
        [ 0.0000e+00,  0.0000e+00]])

参考: 【学习笔记】pytorch中squeeze()和unsqueeze()函数介绍_Jaborie203的博客-CSDN博客_pytorch unsqueeze

你可能感兴趣的:(编程,pytorch)