torch.cat(Tuple[Tensor],dim)->Tensor
输入为Tensor的List/Tuple,输出为一个Tensor
torch.cat()用于对张量的拼接,与数组拼接函数torch.stack()用法类似,二者区别在于输入的变量是数组还是张量。
其中初学者最费解的就是dim的选取,dim的取值范围由输入张量的维度决定,输入为n维张量,dim取值在[0,n-1],接下来我们以实验理解dim不同取值对应的不同操作结果。
初次接触众多博客对dim的讲解为,对于两个二维张量作为输入,dim取0结果为两个张量按行拼接,取1结果为按列拼接,但是对于高维来说就有点难以直观想象结果了,我们尝试三维情况进而总结规律。
先从一个简单的例子入手,输入两个张量为二维,dim取值分别为0和1 :
import torch
X=torch.tensor([[1,2,3],[4,5,6]])
Y=torch.tensor([[7,8,9],[1,4,7]])
input=[X,Y]
A=torch.cat(input,dim=0)
B=torch.cat(input,dim=1)
print("X:{}\nY:{}\ndim0:{}\ndim1:{}".format(X,Y,A,B))
结果如下
可以看出对于两个二维张量作为输入,dim取0结果为两个张量按行拼接,取1结果为按列拼接,但是对于高维来说就有点难以直观想象结果了,我们尝试三维情况进而总结规律。
import torch
X=torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
Y=torch.tensor([[[7,6],[5,4]],[[8,9],[9,10]]])
input=[X,Y]
A=torch.cat(input,dim=0)
B=torch.cat(input,dim=1)
C=torch.cat(input,dim=2)
print("X:{}\nY:{}\ndim0:{}\ndim1:{}\ndim2:{}".format(X,Y,A,B,C))
输入为两个三维张量:
输出:
可见对于dim=0,其输出结果为对两个张量的最高维度包含的内容进行拼接,此例中,X和Y均为三维张量,其最高维度包含的内容为二维,因此,dim=0结果是对其二维张量进行拼接组成的三维张量:
那么对于dim=1的情况,就是对次高维包含内容进行拼接,次高维为2维,其内容为1维,将1维进行拼接得到:
以此类推,对于dim=n-1的情况比较难理解,此例dim=2,对次次高维即1维的内容进行拼接,其中1维的内容是0维,可以理解为1维张量括号内的元素,即每个数字,将其进行拼接,得到结果:
至此,torch.cat()的dim作用已经讲清楚,建议动手实验一下就可以弄明白其中的奥秘!!!