torch.cat函数在二维,三维数据中拼接时候 dim维度 理解

文章目录

  • 一、先来看看`torch.cat`函数的参数以及参数值
  • 二、torch.cat在二维数据中示例说明
    • 1. 在二维数据中 dim 取值 (0,1)
    • 2. 发现了什么问题
    • 总结:
  • 三、从二维数据延伸到三维数据
    • 1. 三维数据`dim`取值为(0,1,2),依次类推
    • 2. 三维数据与二维数据的相关性
    • 3 . 逐一验证

一、先来看看torch.cat函数的参数以及参数值

torch中的cat函数用于沿着指定维度将张量连接起来。具体而言,如果给定一个包含多个张量的序列,通过指定dim参数可以将它们沿着指定维度连接在一起。

函数的常见形式如下:

torch.cat(seq, dim=0, out=None)

其中:

seq:一个Tensor序列,即要拼接的多个张量
dim:连接的维度,默认为0(按行拼接)。可以是任何整数值,具体取值依赖于输入张量的维度。例如,对于二维张量,dim=0表示按行拼接,dim=1表示按列拼接。
out:输出张量。如果指定了此参数,则结果会被写入该张量中,不会创建新的张量。如果没有指定,则会创建新的张量作为结果返回。

二、torch.cat在二维数据中示例说明

1. 在二维数据中 dim 取值 (0,1)

举个例子,假设我们有两个二维张量A和B:

import torch

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])

如果要将它们按行拼接(即在第0维度上拼接),则可以这样做:

python
C = torch.cat((A, B), dim=0)
print(C)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

如果要将它们按列拼接(即在第1维度上拼接),则可以这样做:

python
C = torch.cat((A, B), dim=1)
print(C)
# 输出:
# tensor([[1, 2, 5, 6],
#         [3, 4, 7, 8]])

总之,cat函数可以用于将张量沿着指定的维度连接在一起,非常灵活。需要根据具体情况选择合适的dim参数值来实现多种拼接方式。

2. 发现了什么问题

  1. 按行拼接(或者叫列合并)是指将多个二维张量沿着第0维度即行)拼接在一起,形成一个更大的二维张量。例如,假设有两个二维张量A和B:
A = [[1, 2],
     [3, 4]]

B = [[5, 6],
     [7, 8]]

那么将它们按行拼接之后得到的结果就是:

C = [[1, 2],
     [3, 4],
     [5, 6],
     [7, 8]]

可以看到,新的张量C比原来的张量A和B都要,因为它包含了两个输入张量中所有的行

  1. 按列拼接(或者叫行合并)是指将多个二维张量沿着第1维度(即)拼接在一起,形成一个更宽的二维张量。例如,假设有两个二维张量A和B:
A = [[1, 2],
     [3, 4]]

B = [[5, 6],
     [7, 8]]

那么将它们按列拼接之后得到的结果就是:

C = [[1, 2, 5, 6],
     [3, 4, 7, 8]]

可以看到,新的张量C比原来的张量A和B都要宽,因为它包含了两个输入张量中所有的列

总结:

  1. dim = 0时候, 按行拼接(或者叫列合并),此时 输出的结果中应当包含拼接数据所有行
A = [[1, 2],
     [3, 4]]

B = [[5, 6],
     [7, 8]]

那么将它们按行拼接之后得到的结果就是:

C = [[1, 2],
     [3, 4],
     [5, 6],
     [7, 8]]
  1. dim = 1时候, 按列拼接(或者叫行合并),此时 输出的结果中应当包含拼接数据所有列
A = [[1, 2],
     [3, 4]]

B = [[5, 6],
     [7, 8]]

那么将它们按列拼接之后得到的结果就是:

C = [[1, 2, 5, 6],
     [3, 4, 7, 8]]

三、从二维数据延伸到三维数据

1. 三维数据dim取值为(0,1,2),依次类推

在三维张量中,dim参数的取值范围为0、1、2,具体的含义如下:

dim=0:表示沿着第0维度进行拼接。这意味着将两个包含多个矩阵的三维张量连接起来,形成一个更高的三维张量。
dim=1:表示沿着第1维度进行拼接。这意味着将两个包含多个行向量的三维张量连接起来,形成一个更宽的三维张量。
dim=2:表示沿着第2维度进行拼接。这意味着将两个包含多个列向量的三维张量连接起来,形成一个更深的三维张量。

2. 三维数据与二维数据的相关性

可以看到,三维比二维多了一个维度,0维度。事实上,三维数据中的 12 维度,分别对应二维数据的 01维度,而三维数据中的 0 维度,含义就是 有多少个二维数据
比如 :4x3x2 含义就是 43x2矩阵

3 . 逐一验证

  1. dim取值为0的时候
    例如,如果有两个三维张量A和B,每个张量都包含2个2x2的矩阵,我们可以按照第0维度(即沿着深度方向)将它们拼接在一起:
import torch

A = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]]])

B = torch.tensor([[[9, 10], [11, 12]],
                  [[13, 14], [15, 16]]])

C = torch.cat((A, B), dim=0)
print(C.shape)  # 输出:torch.Size([4, 3, 2])

可以看到,输出的结果 是不是 将 A和B的 2个矩阵拼接,就是42x2的矩阵,输出的结果也就是 4x2x2

  1. dim取值为 1 的时候

上面有说到:三维数据中的 12 维度,分别对应二维数据的 01维度,而三维数据中的 0 维度,含义就是 有多少个二维数据

C = torch.cat((A, B), dim=1)

#沿着维度1拼接的结果:
 tensor([[[ 1,  2],
         [ 3,  4],
         [ 9, 10],
         [11, 12]],

        [[ 5,  6],
         [ 7,  8],
         [13, 14],
         [15, 16]]])
#沿着维度1拼接的结果的形状: torch.Size([2, 4, 2])

看到结果是不是验证了之前的说法,抛开0维度,当dim取值1时候。相当于二维数据中dim取值为0时候,也就是 当 dim = 0时候, 按行拼接(或者叫列合并),此时 输出的结果中应当包含拼接数据所有行

  1. dim取值为 2 的时候
# 沿着维度2拼接
C2 = torch.cat((A, B), dim=2)
print("沿着维度2拼接的结果:\n", C2)
print("沿着维度2拼接的结果的形状:", C2.shape)

#沿着维度2拼接的结果:
 tensor([[[ 1,  2,  9, 10],
         [ 3,  4, 11, 12]],

        [[ 5,  6, 13, 14],
         [ 7,  8, 15, 16]]])
#沿着维度2拼接的结果的形状: torch.Size([2, 2, 4])

看结果:抛开0维度,当dim取值2时候。相当于二维数据中dim取值为1时候,也就是:当 dim = 1时候, 按列拼接(或者叫行合并),此时 输出的结果中应当包含拼接数据所有列

你可能感兴趣的:(一周掌握PyTorch,python,pytorch,深度学习)