torch.cat
函数的参数以及参数值torch
中的cat
函数用于沿着指定维度
将张量连接起来。具体而言,如果给定一个包含多个张量的序列,通过指定dim
参数可以将它们沿着指定维度
连接在一起。
函数的常见形式如下:
torch.cat(seq, dim=0, out=None)
其中:
seq
:一个Tensor序列,即要拼接的多个张量
。
dim
:连接的维度,默认为0(按行拼接)。可以是任何整数值,具体取值依赖于输入张量的维度。例如,对于二维张量,dim=0表示按行拼接,dim=1表示按列拼接。
out
:输出张量。如果指定了此参数,则结果会被写入该张量中,不会创建新的张量。如果没有指定,则会创建新的张量作为结果返回。
举个例子,假设我们有两个二维张量A和B:
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
如果要将它们按行拼接
(即在第0维度
上拼接),则可以这样做:
python
C = torch.cat((A, B), dim=0)
print(C)
# 输出:
# tensor([[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]])
如果要将它们按列拼接
(即在第1维度
上拼接),则可以这样做:
python
C = torch.cat((A, B), dim=1)
print(C)
# 输出:
# tensor([[1, 2, 5, 6],
# [3, 4, 7, 8]])
总之,cat函数可以用于将张量沿着指定的维度连接在一起,非常灵活。需要根据具体情况选择合适的dim参数值来实现多种拼接方式。
按行拼接
(或者叫列合并
)是指将多个二维张量沿着第0维度
(即行
)拼接在一起,形成一个更大的二维张量。例如,假设有两个二维张量A和B:A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
那么将它们按行拼接之后得到的结果就是:
C = [[1, 2],
[3, 4],
[5, 6],
[7, 8]]
可以看到,新的张量C比原来的张量A和B都要长
,因为它包含了两个输入张量中所有的行
。
按列拼接
(或者叫行合并
)是指将多个二维张量沿着第1维度
(即列
)拼接在一起,形成一个更宽的二维张量。例如,假设有两个二维张量A和B:A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
那么将它们按列拼接之后得到的结果就是:
C = [[1, 2, 5, 6],
[3, 4, 7, 8]]
可以看到,新的张量C比原来的张量A和B都要宽,因为它包含了两个输入张量中所有的列
。
dim = 0
时候, 按行拼接
(或者叫列合并
),此时 输出
的结果中应当包含拼接数据
的所有行
A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
那么将它们按行拼接之后得到的结果就是:
C = [[1, 2],
[3, 4],
[5, 6],
[7, 8]]
dim = 1
时候, 按列拼接
(或者叫行合并
),此时 输出
的结果中应当包含拼接数据
的所有列
A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
那么将它们按列拼接之后得到的结果就是:
C = [[1, 2, 5, 6],
[3, 4, 7, 8]]
dim
取值为(0,1,2),依次类推在三维张量中,dim参数的取值范围为0、1、2,具体的含义如下:
dim=0
:表示沿着第0维度
进行拼接。这意味着将两个包含多个矩阵的三维张量连接起来,形成一个更高的三维张量。
dim=1
:表示沿着第1维度
进行拼接。这意味着将两个包含多个行向量的三维张量连接起来,形成一个更宽的三维张量。
dim=2
:表示沿着第2维度
进行拼接。这意味着将两个包含多个列向量的三维张量连接起来,形成一个更深的三维张量。
可以看到,三维比二维多了一个维度,0
维度。事实上,三维数据中的 1
和 2
维度,分别对应二维数据的 0
和1
维度,而三维数据中的 0
维度,含义就是 有多少个二维数据
,
比如 :4x3x2
含义就是 4
个 3x2
的矩阵
。
dim
取值为0
的时候第0维度
(即沿着深度方向)将它们拼接在一起:import torch
A = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])
B = torch.tensor([[[9, 10], [11, 12]],
[[13, 14], [15, 16]]])
C = torch.cat((A, B), dim=0)
print(C.shape) # 输出:torch.Size([4, 3, 2])
可以看到
,输出的结果 是不是 将 A和B的 2个矩阵拼接,就是4
个2x2
的矩阵,输出的结果
也就是 4x2x2
dim
取值为 1
的时候上面有说到:三维数据中的 1
和 2
维度,分别对应二维数据的 0
和1
维度,而三维数据中的 0
维度,含义就是 有多少个二维数据
,
C = torch.cat((A, B), dim=1)
#沿着维度1拼接的结果:
tensor([[[ 1, 2],
[ 3, 4],
[ 9, 10],
[11, 12]],
[[ 5, 6],
[ 7, 8],
[13, 14],
[15, 16]]])
#沿着维度1拼接的结果的形状: torch.Size([2, 4, 2])
看到结果是不是验证了之前的说法,抛开0维度
,当dim取值1
时候。相当于
二维数据中dim取值为0时候,也就是 当 dim = 0
时候, 按行拼接
(或者叫列合并
),此时 输出
的结果中应当包含拼接数据
的所有行
dim
取值为 2
的时候# 沿着维度2拼接
C2 = torch.cat((A, B), dim=2)
print("沿着维度2拼接的结果:\n", C2)
print("沿着维度2拼接的结果的形状:", C2.shape)
#沿着维度2拼接的结果:
tensor([[[ 1, 2, 9, 10],
[ 3, 4, 11, 12]],
[[ 5, 6, 13, 14],
[ 7, 8, 15, 16]]])
#沿着维度2拼接的结果的形状: torch.Size([2, 2, 4])
看结果:抛开0维度
,当dim取值2
时候。相当于
二维数据中dim取值为1
时候,也就是:当 dim = 1
时候, 按列拼接
(或者叫行合并
),此时 输出
的结果中应当包含拼接数据
的所有列