Pytorch改变Tensor维度

在pytorch中,有比较多的函数可以对Tensor的维度进行改变,下面笔者就简单列出一些。

1.torch.squeeze()/torch.unsqueeze()
squeeze函数是对张量的维度进行压缩,去掉维数为1的维度;相反unsqueeze函数是对张量进行维度扩张。

多说无益,赶紧上马:

import torch

#这里随机产生一个多维张量shape = ([3, 1, 2, 1, 4, 1])
a = torch.randn(3,1,2,1,4,1)
b = a.squeeze()
print(b.size())            # output:torch.Size([3, 2, 4])

#对特定的维度进行压缩
c = a.squeeze(1)
print(c.size())          # output: torch.Size([3, 2, 1, 4, 1])

#如果要对两个维度进行压缩,则一定要一个一个进行压缩,不能一次同时压缩
#比如对a的1、3维度压缩,如下。这里不能a.squeeze(1, 3),程序会报错!
d = a.squeeze(1).squeeze(2)  #是在c的第2维度上进行的压缩
print(d.size())         #output:torch.Size([3, 2, 4, 1])


#同样,unsqueeze也类似
# b :torch.Size([3, 2, 4])
e = b.unsqueeze(1)
print(e.size())            #output:torch.Size([3, 1, 2, 4])

f = b.unsqueeze(1).unsqueeze(dim = 3)
print(f.size())            # output : torch.Size([3, 1, 2, 1, 4])

2.torch.view()

把原先Tensor中的数据按照行优先的顺序排成一个一维的数据,然后按照参数组合成其他维度的Tensor。

多说无益,赶紧上马:

import torch
a = torch.randn(2,3,4)
print(a)   # output: tensor([[[-0.8822,  0.6797,  0.5335,  0.1103],
                              [ 0.6791, -2.1690, -0.6625, -0.2989],
                              [-1.3134, -1.1234, -0.7303,  2.1314]],

                             [[-0.8697,  0.8352,  0.9058, -1.2924],
                              [ 1.3043, -0.8773,  0.5054,  0.4219],
                              [-1.0243, -2.5556, -0.6324, -1.6356]]])
b = a.view(1, 24)
print(b)    # output: tensor([[-0.8822,  0.6797,  0.5335,  0.1103,  0.6791, -2.1690, -0.6625, -0.2989,
                          #   -1.3134, -1.1234, -0.7303,  2.1314, -0.8697,  0.8352,  0.9058, -1.2924,
                         #   1.3043, -0.8773,  0.5054,  0.4219, -1.0243, -2.5556, -0.6324, -1.6356]])
print(b.size())    # output : torch.Size([1, 24])

c = a.view(3, 8)  # output : tensor([[-0.8822,  0.6797,  0.5335,  0.1103,  0.6791, -2.1690, -0.6625, -0.2989],
                                 #   [-1.3134, -1.1234, -0.7303,  2.1314, -0.8697,  0.8352,  0.9058, -1.2924],
                                 # [ 1.3043, -0.8773,  0.5054,  0.4219, -1.0243, -2.5556, -0.6324, -1.6356]])
print(c.size())   # output: torch.Size([3, 8])

d = a.view(3, 2, 4)
print(d.size())   # output : torch.Size([3, 2, 4])

e = a.view(3, 2, 2, 2) # output: tensor([[[[-0.8822,  0.6797],
                                           [ 0.5335,  0.1103]],

                                          [[ 0.6791, -2.1690],
                                           [-0.6625, -0.2989]]],


                                         [[[-1.3134, -1.1234],
                                           [-0.7303,  2.1314]],

                                          [[-0.8697,  0.8352],
                                           [ 0.9058, -1.2924]]],


                                         [[[ 1.3043, -0.8773],
                                           [ 0.5054,  0.4219]],

                                          [[-1.0243, -2.5556],
                                           [-0.6324, -1.6356]]]])
print(e.size())  # output: torch.Size([3, 2, 2, 2])

3.permute
将Tensor维度进行换位。

import torch
a = torch.randn(2,3,4)
print(a)   # output: tensor([[[-0.8822,  0.6797,  0.5335,  0.1103],
                              [ 0.6791, -2.1690, -0.6625, -0.2989],
                              [-1.3134, -1.1234, -0.7303,  2.1314]],

                             [[-0.8697,  0.8352,  0.9058, -1.2924],
                              [ 1.3043, -0.8773,  0.5054,  0.4219],
                              [-1.0243, -2.5556, -0.6324, -1.6356]]])
>>> b = a.permute(0,2,1)
>>> print(b)
tensor([[[-0.8822,  0.6791, -1.3134],
         [ 0.6797, -2.1690, -1.1234],
         [ 0.5335, -0.6625, -0.7303],
         [ 0.1103, -0.2989,  2.1314]],

        [[-0.8697,  1.3043, -1.0243],
         [ 0.8352, -0.8773, -2.5556],
         [ 0.9058,  0.5054, -0.6324],
         [-1.2924,  0.4219, -1.6356]]])
>>> print(b.size())
torch.Size([2, 4, 3])

4.stack/cat

torch.stack(sequence, dim=0, out=None),
torch.cat(sequence, dim=0, out=None),

sequence表示Tensor列表,dim表示拼接的维度,stack是建立一个新的维度,然后再在该纬度上进行拼接;而cat是在已有的维度上拼接。

不理解?直接上马:

>>> import torch
>>> t1 = torch.tensor([1,1,1])
>>> t2 = torch.tensor([2,2,2])
>>> t3 = torch.tensor([3,3,3])
>>> torch.cat((t1,t2,t3),dim=0)
tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])

>>> torch.stack((t1,t2,t3), dim=0)
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

>>> torch.cat((t1.unsqueeze(0), t2.unsqueeze(0),t3.unsqueeze(0)),dim=0)
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])

>>> torch.stack((t1,t2,t3),dim=1)
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

>>> torch.cat((t1.unsqueeze(1), t2.unsqueeze(1), t3.unsqueeze(1)), dim=1)
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

通过上面的示例可以看出,先使用unsqueeze对Tensor进行维度扩张,然后再cat便可以得到与stack一样的结果。

你可能感兴趣的:(pytorch,python)