squeeze()和unsqueeze()的使用

目录

说明

定义数据

降维

增维

定位降维


说明 

        (如需制作课程资源(幻灯片、实训手册、视频等)请私信给我)

squeeze()是降维函数,unsqueeze()是增维函数。具体使用方法如下:

定义数据

import torch
m = torch.Tensor(1,2,3)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00],
         [3.0674e-41, 1.1210e-43, 0.0000e+00]]])

维度为

torch.Size([1, 2, 3])

降维

squeeze()缺省函数值,是去除所有维度为1的维度。定义函数值,在后面的【定位增维】中介绍。

m = m.squeeze()
m
tensor([[1.0769e-02, 3.0674e-41, 6.9797e+00],
        [3.0674e-41, 1.1210e-43, 0.0000e+00]])

维度为

torch.Size([2, 3])

增维

dim为unsqueeze()的函数值,这里值为0,表示在第1维插入一个维度

m = m.unsqueeze(0)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00],
         [3.0674e-41, 1.1210e-43, 0.0000e+00]]])

维度为

torch.Size([1, 2, 3])

dim为1,表示在第2维插入一个维度

m = m.unsqueeze(1)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00]],

        [[3.0674e-41, 1.1210e-43, 0.0000e+00]]])
torch.Size([2, 1, 3])

dim为2,表示在第3维插入一个维度

m = m.unsqueeze(2)
m
tensor([[[1.0769e-02],
         [3.0674e-41],
         [6.9797e+00]],

        [[3.0674e-41],
         [1.1210e-43],
         [0.0000e+00]]])
torch.Size([2, 3, 1])

dim为-1,表示在倒数第1维插入一个维度

m = m.unsqueeze(-1)
m
tensor([[[1.0769e-02],
         [3.0674e-41],
         [6.9797e+00]],

        [[3.0674e-41],
         [1.1210e-43],
         [0.0000e+00]]])
torch.Size([2, 3, 1])

dim为-2,表示在倒数第2维插入一个维度

m = m.unsqueeze(-2)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00]],

        [[3.0674e-41, 1.1210e-43, 0.0000e+00]]])
torch.Size([2, 1, 3])

dim为-3,表示在倒数第3维插入一个维度

m = m.unsqueeze(-3)
m
tensor([[[1.0769e-02, 3.0674e-41, 6.9797e+00],
         [3.0674e-41, 1.1210e-43, 0.0000e+00]]])
torch.Size([1, 2, 3])

定位降维

squeeze() 缺省值为去掉所有维度为1的维度。

m = torch.Tensor(1,2,1,3,1)
m
tensor([[[[[6.9997e+00],
           [3.0674e-41],
           [7.0005e+00]]],


         [[[3.0674e-41],
           [8.9683e-44],
           [0.0000e+00]]]]])
torch.Size([1, 2, 1, 3, 1])
m = m.squeeze()
m
tensor([[6.9997e+00, 3.0674e-41, 7.0005e+00],
        [3.0674e-41, 8.9683e-44, 0.0000e+00]])
torch.Size([2, 3])

squeeze(0)若第一维度为1则去除第一维度(第一维度不为1则不去除)

m = torch.Tensor(1,2,1,3,1)
m
tensor([[[[[7.8503e+00],
           [3.0674e-41],
           [6.9897e+00]]],


         [[[3.0674e-41],
           [8.9683e-44],
           [0.0000e+00]]]]])
torch.Size([1, 2, 1, 3, 1])
m = m.squeeze(0)
m
tensor([[[[7.8503e+00],
          [3.0674e-41],
          [6.9897e+00]]],


        [[[3.0674e-41],
          [8.9683e-44],
          [0.0000e+00]]]])
torch.Size([2, 1, 3, 1])

squeeze(1)若第二维度为1则去除第二维度(第二维度不为1则不去除)

m = m.squeeze(1)
m
tensor([[[7.8503e+00],
         [3.0674e-41],
         [6.9897e+00]],

        [[3.0674e-41],
         [8.9683e-44],
         [0.0000e+00]]])
torch.Size([2, 3, 1])

squeeze(2)若第三维度为1则去除第二维度(第三维度不为1则不去除)

m = m.squeeze(2)
m
tensor([[7.8503e+00, 3.0674e-41, 6.9897e+00],
        [3.0674e-41, 8.9683e-44, 0.0000e+00]])
torch.Size([2, 3])

你可能感兴趣的:(NLP,深度学习,python,pytorch)