Pytorch中 stack和cat, squeeze 和 unsqueeze 之间的差别

目录

1、torch.cat

2、torch.stack

3、torch.squeeze()

4、 torch.unsqueeze()


1、torch.cat

torch.cat(tensors, dim=0, *, out=None) → Tensor

官方解释:利用给定的维度连接给定的数组序列(cat代表concatenate),所有数组必须具有相同的形状(连接维度除外)或为空。相当于按指定维度将数组进行拼接

参数解释:

  • tensors:要连接的数组序列(元组tuple或者列表list)
  • dim:数组连接的维度
  • out:输出数组(一般用不到,如果有输出,则可以直接进行赋值操作)

注意:
tensors输入的必须是数组序列,不能是单个数组;
②输入的数组序列除了dim维度,其他维度必须形状相同

对数据沿着某一维度进行拼接。cat后数据的总维数不变, 利用torch.cat()沿dim拼接,在形状上看相当于对dim进行相加,其余维度大小不变,利用这个思想,可以很容易理解高维数组的拼接

'''
cat之后的数据 z[0,:,:]是x的值,z[1,:,:]是y的值。
'''
x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.cat((x,y))  #默认dim=0
print(z.shape)
#torch.Size([2, 2, 3])

"""
dim=0,为以第一维为基准拼接,对于一个张量的维度,有几个放括号就是几维,
下面例子a0与a1均为4维张量,因此以第0维拼接就是将第一个中括号内的内容进行拼接。
最终的尺度大小为(2,1,2,4)
"""
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.Size([1, 1, 2, 4])

torch.cat((a0,a1),dim=0)

tensor([[[[1, 1, 1, 1],
          [2, 2, 2, 2]]],

        [[[3, 3, 3, 3],
          [4, 4, 4, 4]]]])

torch.Size([2, 1, 2, 4])

"""
dim=1,以第1维进行拼接,将第二个括号内的内容进行拼接
"""

torch.cat((a0,a1),dim=1)

tensor([[[[1, 1, 1, 1],
          [2, 2, 2, 2]],

         [[3, 3, 3, 3],
          [4, 4, 4, 4]]]])
torch.Size([1, 2, 2, 4])

"""
dim=2,类推
"""
torch.cat((a0,a1),dim=2)
tensor([[[[1, 1, 1, 1],
          [2, 2, 2, 2],
          [3, 3, 3, 3],
          [4, 4, 4, 4]]]])
torch.Size([1, 1, 4, 4])

"""
dim=3,类推
"""
torch.cat((a0,a1),dim=3)
tensor([[[[1, 1, 1, 1, 3, 3, 3, 3],
          [2, 2, 2, 2, 4, 4, 4, 4]]]])
torch.Size([1, 1, 2, 8])

2、torch.stack

torch.stack(tensors, dim=0, *, out=None) → Tensor

官方解释:沿着新的维度连接一系列数组,所有的数组都需要具有相同的大小。相当于先将多个n维数组进行扩维操作,然后再拼接为一个n+1维的数组其中最关键的是stack之后的数据的size会多出一个维度,而cat则不会

参数解释:

  • tensors:要连接的数组序列(元组tuple或者列表list)
  • dim:要插入的维度,大小必须介于0和需要拼接的数组维数之间(dim最大不超过数组的维数)
  • out:输出数组(与cat()类似,一般用不到)

注意:
①与cat类似,必须输入数组序列,不能是单个数组;
②输入的所有数组序列形状(尺寸)必须一致(这里与cat有区别)。

首先会将要拼接的tensor按dim增加,比如说对于stack的操作来说,先将a,b,c的维度变为(1,3,3),然后再按dim进行torch.cat的操作。

从上边两个例子,可以看出,对于torch.stack来说,会先将原始数据维度扩展一维,然后再按照dim维度进行torch.cat拼接,具体拼接操作同torch.cat类似

import torch
a=torch.randn((1,3,4,4)) #[N,c,w,h]
b=torch.stack((a,a))   # stack默认 dim = 0
# (2, 1, 3, 4, 4)

c=torch.stack((a,a), 1) # stack dim = 1
# (1, 2, 3, 4, 4)

d=torch.stack((a,a),2) # stack dim = 2
# (1, 3, 2, 4, 4)


'''
所以stack的之后的数据也就很好理解了,z[0,...]的数据是x,z[1,...]的数据是y。
'''
x=torch.randn((1,2,3))
y=torch.randn((1,2,3))
z=torch.stack((x,y))#默认dim=0
print(z.shape)
#torch.Size([2, 1, 2, 3])

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.Size([1, 1, 2, 4])

torch.stack((a0,a1),dim=0)

tensor([[[[[1, 1, 1, 1],
           [2, 2, 2, 2]]]],

        [[[[3, 3, 3, 3],
           [4, 4, 4, 4]]]]])
torch.Size([2, 1, 1, 2, 4])

torch.stack((a0,a1),dim=1)
tensor([[[[[1, 1, 1, 1],
           [2, 2, 2, 2]]],

         [[[3, 3, 3, 3],
           [4, 4, 4, 4]]]]])
torch.Size([1, 2, 1, 2, 4])
# cat和stack不指定维度的时候默认都是0
a = torch.tensor(torch.rand(1,2,3))
>>>tensor([[[0.0168, 0.5604, 0.5117],
         [0.7407, 0.1112, 0.5702]]])
         
torch.cat((a,a))    
>>>tensor([[[0.0168, 0.5604, 0.5117],
         [0.7407, 0.1112, 0.5702]],
        [[0.0168, 0.5604, 0.5117],
         [0.7407, 0.1112, 0.5702]]])    # 默认为0维,所以最终shape为(2,2,3)
 
torch.stack((a,a),0)   
>>>tensor([[[[0.0168, 0.5604, 0.5117],
          [0.7407, 0.1112, 0.5702]]],
        [[[0.0168, 0.5604, 0.5117],       # 相当于原有的维度(1,2,3)->(1,1,2,3) 再在0维处进行拼接
          [0.7407, 0.1112, 0.5702]]]])    # 最终shape为(2,1,2,3) 
          
torch.stack((a,a),1)
>>>tensor([[[[0.0168, 0.5604, 0.5117],
          [0.7407, 0.1112, 0.5702]],
         [[0.0168, 0.5604, 0.5117],     # 相当于原有的维度(1,2,3)->(1,1,2,3) 再在一维进行拼接
          [0.7407, 0.1112, 0.5702]]]])  # 最终shape为(1,2,2,3) 

torch.stack((a,a),2)
tensor([[[[0.0168, 0.5604, 0.5117],
          [0.0168, 0.5604, 0.5117]],
         [[0.7407, 0.1112, 0.5702],      # 相当于原有的维度(1,2,3)->(1,2,1,3) 再在第二维维进行拼接
          [0.7407, 0.1112, 0.5702]]]])   # 最终shape为(1,2,2,3) 

# 由于stakc 拼接的时候默认新增了一个维度。所以可以在第四维进行拼接
torch.stack((a,a),3)   # cat在第三维的时候会报错,因为他只有三维
tensor([[[[0.0168, 0.0168],
          [0.5604, 0.5604],
          [0.5117, 0.5117]],
         [[0.7407, 0.7407],
          [0.1112, 0.1112],   # 相当于原有的维度(1,2,3)->(1,2,3,1) 再在第三维维进行拼接
          [0.5702, 0.5702]]]])    # 最终shape维(1,2,3,2)

具体怎么堆叠:

import torch
a=torch.arange(1,7).reshape((3,2))
#tensor([[1, 2],
#        [3, 4],
#        [5, 6]])
b=torch.arange(10,70,10).reshape((3,2))
#tensor([[10, 20],
#        [30, 40],
#        [50, 60]])
c=torch.arange(100,700,100).reshape((3,2))
#tensor([[100, 200],
#        [300, 400],
#        [500, 600]])
d=torch.stack((a,b,c)) #(3, 3, 2)
#tensor([[[  1,   2],
#         [  3,   4],
#         [  5,   6]],

#        [[ 10,  20],
#         [ 30,  40],
#         [ 50,  60]],

#        [[100, 200],
#         [300, 400],
#         [500, 600]]])
e=torch.stack((a,b,c),1) #(3, 3, 2)
#tensor([[[  1,   2],
#         [ 10,  20],
#         [100, 200]],

#        [[  3,   4],
#         [ 30,  40],
#         [300, 400]],

#        [[  5,   6],
#         [ 50,  60],
#         [500, 600]]])
f=torch.stack((a,b,c),2) #(3, 2, 3)
#tensor([[[  1,  10, 100],
#         [  2,  20, 200]],

#        [[  3,  30, 300],
#         [  4,  40, 400]],

#        [[  5,  50, 500],
#         [  6,  60, 600]]])

3、torch.squeeze()

torch.squeeze(input, dim=None, out=None)

作用:降维, 去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。

1、将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),不给定dim时,那么默认去掉所有维度为1的维, 输出形状就为: (A×B×C×D)

2、当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B×1×C×1×D),

  • squeeze(input, 0) 将会保持张量不变 (A×1×B×1×C×1×D),因为 维度0的值A大于1
  • squeeze(input, 1),形状会变成 (A×B×1×C×1×D),只去掉dim指定的维度,该维度为大小1
  • -1,去除最后维度值为1的维度

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
参数:

  • input (Tensor) – 输入张量
  • dim (int, optional) – 如果给定,则input只会在给定维度挤压
  • out (Tensor, optional) – 输出张量

为何只去掉 1 呢?
多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。

import torch
b = torch.Tensor(1, 3, 1, 5)
b.shape
Out[28]: torch.Size([3, 5])

# 不加参数,去掉所有为元素个数为1的维度, 也即维度为 1 的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1的维度
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([3, 1, 5])

# 加上参数,去掉第二维的元素为1的维度,不起作用,因为第二维有3个元素(维度为3)
b_ = b.squeeze(1)
b_.shape 
Out[32]: torch.Size([1, 3, 1, 5])

4、 torch.unsqueeze()

torch.unsqueeze(input, dim, out=None)

作用:扩展维度, 增加大小为1的维度,也就是返回一个新的张量,对输入的指定位置插入维度 1且必须指明维度

   返回一个新的张量,对输入的既定位置插入维度 1

注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

参数:

  • tensor (Tensor) – 输入张量
  • dim (int) – 插入维度的索引
  • out (Tensor, optional) – 结果张量
import torch

b = torch.Tensor(3, 5)

# 在0维增加一个维度
b_ = b.unsqueeze(0)
b_.shape
# torch.Size([1, 3, 5])

在第1维增加一个维度
b_ = b.unsqueeze(1)
b_.shape
# torch.Size([3, 1, 5])

你可能感兴趣的:(AI之路,-,Face,深度学习,python,人工智能)