【Pytorch学习笔记】torch.cat() 与 torch.stack()

torch.stack()

stack在英文中有“堆叠的意思”。所以stack通常是把一些低纬(二维)的tensor堆叠为一个高维(三维)的tensor。
stack()官方解释:torch.stack[source] → Tensor :
函数目的: 沿着一个新维度对输入张量序列进行拼接 。其中序列中所有的 张量 都应该为相同形状。

outputs = torch.stack(inputs, dim=0)  # → Tensor

参数:

  • inputs : 待连接的张量序列。
    注:python的序列数据只有list和tuple
  • dim : 新的维度, 必须在0到len(outputs)之间。
    注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

例子

  1. 准备2个tensor数据,每个的shape都是[3,3]
T1 = torch.tensor([[1, 2, 3],
        		[4, 5, 6],
        		[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
        		[40, 50, 60],
        		[70, 80, 90]])
  1. 测试stack函数
R0 = torch.stack((T1, T2), dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
"""
R0:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
R0.shape:
 torch.Size([2, 3, 3])
"""

R1 = torch.stack((T1, T2), dim=1)
print("R1.shape:\n", R1.shape)
"""
R1:
 tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]])
R1.shape:
 torch.Size([3, 2, 3])

"""

R2 = torch.stack((T1, T2), dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
"""
R2:
 tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]])
R2.shape:
 torch.Size([3, 3, 2])

"""
R3 = torch.stack((T1, T2), dim=3)
print("R3:\n", R3)
print("R3.shape:\n", R3.shape)
"""
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
"""

torch.cat()

一般torch.cat()是为了把函数torch.stack()得到tensor进行拼接而存在的。torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。
函数目的:
在给定维度上对输入的张量序列seq 进行连接操作。

outputs = torch.cat(inputs, dim=0) → Tensor

参数:

  • inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列。
  • dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。

重点:

  • 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
  • 维度不可以超过输入数据的任一个张量的维度

例子

  1. 准备数据,每个的shape都是[2,3]
x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
print("x1:\n", x1)
print("x1.shape:\n", x1.shape)
'''
x1:
 tensor([[11, 21, 31],
        [21, 31, 41]], dtype=torch.int32)
x1.shape:
 torch.Size([2, 3])
'''
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]])
print("x2:\n", x2)
print("x2.shape:\n", x2.shape)
'''
x2:
 tensor([[12, 22, 32],
        [22, 32, 42]])
x2.shape:
 torch.Size([2, 3])
'''
  1. 合成inputs
inputs = [x1, x2]
print("inputs:\n", inputs)
'''
inputs:
 [tensor([[11, 21, 31],
        [21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32],
        [22, 32, 42]])]
'''
  1. 查看结果, 测试不同的dim拼接结果
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
'''
R0:
 tensor([[11, 21, 31],
        [21, 31, 41],
        [12, 22, 32],
        [22, 32, 42]])
R0.shape:
 torch.Size([4, 3])
'''

R1 = torch.cat(inputs, dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
'''
R1:
 tensor([[11, 21, 31, 12, 22, 32],
        [21, 31, 41, 22, 32, 42]])
R1.shape:
 torch.Size([2, 6])
'''

R2 = torch.cat(inputs, dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''

总结

  1. torch.stack()是新增一维,属于增维操作(22,22 → 222)。
  2. torch.cat()是在特定维度上进行拼接(22, 22 → 2*4)。

参考链接:https://blog.csdn.net/qq_40507857/article/details/119854085

你可能感兴趣的:(深度学习,pytorch,学习,深度学习)