[Pytorch] 详解 torch.cat()

1. 定义

官方手册中描述为:

torch.cat(inputs, dimension=0) → Tensor

在给定维度上对输入的张量序列seq 进行连接操作。

torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数可以通过下面例子更好的理解。

参数:

  • inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
  • dimension (int, optional) – 沿着此维连接张量序列。

2. 例子

>>> x = torch.randn(2, 3)
>>> x

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]

>>> torch.cat((x, x, x), 0)

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]

>>> torch.cat((x, x, x), 1)

 0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735
 
[torch.FloatTensor of size 2x9]

torch.cat((x, x, x), 1)中的 0 or 1 就是指示的维度。

除此之外,可以指示为-1。

我将举几个例子

如图,a是2x3 b是2x5的一个张量

[Pytorch] 详解 torch.cat()_第1张图片
拼接后:
[Pytorch] 详解 torch.cat()_第2张图片

一句话总结:上下拼接要列数相同,左右拼接要行数相同。

另,用torch.cat拼接list里的tensor:

先整个list:
[Pytorch] 详解 torch.cat()_第3张图片

在这里插入图片描述
可以清楚的看到已经拼接好了,即参数可以直接传入一个seq

你可能感兴趣的:(Python,AI,深度学习,python,人工智能,pytorch,torch)