作者:机器视觉全栈er
网站:cvtutorials.com
合和分是相反的两个过程。
torch.cat: cat是英文单词concatenate的缩写,表示连接的意思。torch.cat将一系列的tensor连接起来,这里要求被连接的tensor要有同样的shape,这个函数的用法如下:
torch.cat(tensors, dim=0, *, out=None)
这里的tensors的常见形式为(x1, x2, x3, …),其中每个都是一个tensor,dim表示多个tensor沿着哪个维度进行连接,默认是0,表示沿着第0维连接,我们以二阶tensor(矩阵)为例:
>>> cvtutorials = torch.randn(2, 2)
>>> cvtutorials
tensor([[ 1.0078, 0.6374],
[-1.2574, -1.2475]])
>>> torch.cat((cvtutorials, cvtutorials), 0)
tensor([[ 1.0078, 0.6374],
[-1.2574, -1.2475],
[ 1.0078, 0.6374],
[-1.2574, -1.2475]])
>>> torch.cat((cvtutorials, cvtutorials), 1)
tensor([[ 1.0078, 0.6374, 1.0078, 0.6374],
[-1.2574, -1.2475, -1.2574, -1.2475]])
从中可以看出,沿着矩阵的第0维拼接,指的是增加矩阵的行数,沿着矩阵的第1维拼接,指的是增加矩阵的列数。
torch.column_stack: 将多个tensor沿着水平方向进行堆叠,即按列拼接,用法如下:
torch.column_stack(tensors, *, out=None)
举个简单的例子:
>>> a = torch.arange(1, 5)
>>> a
tensor([1, 2, 3, 4])
>>> cvtutorials = torch.arange(6, 10)
>>> cvtutorials
tensor([6, 7, 8, 9])
>>> torch.column_stack((a, cvtutorials))
tensor([[1, 6],
[2, 7],
[3, 8],
[4, 9]])
torch.dstack, torch.vstack, torch.hstack:分别是按照深度、垂直和水平三个方向进行堆叠,用法如下:
torch.dstack(tensors, *, out=None)
torch.hstack(tensors, *, out=None)
torch.vstack(tensors, *, out=None)
对上面三个函数用些例子进行说明比较,如下:
>>> import torch
>>> cvtutorials1 = torch.randn(1, 3)
>>> cvtutorials2 = torch.randn(3, 1)
>>> cvtutorials1
tensor([[-1.8013, 2.4992, -0.2126]])
>>> cvtutorials2
tensor([[-1.4519],
[ 0.2338],
[-1.8889]])
>>> torch.dstack((cvtutorials1, cvtutorials1))
tensor([[[-1.8013, -1.8013],
[ 2.4992, 2.4992],
[-0.2126, -0.2126]]])
>>> torch.vstack((cvtutorials1, cvtutorials1))
tensor([[-1.8013, 2.4992, -0.2126],
[-1.8013, 2.4992, -0.2126]])
>>> torch.hstack((cvtutorials1, cvtutorials1))
tensor([[-1.8013, 2.4992, -0.2126, -1.8013, 2.4992, -0.2126]])
>>> torch.dstack((cvtutorials2, cvtutorials2))
tensor([[[-1.4519, -1.4519]],
[[ 0.2338, 0.2338]],
[[-1.8889, -1.8889]]])
>>> torch.vstack((cvtutorials2, cvtutorials2))
tensor([[-1.4519],
[ 0.2338],
[-1.8889],
[-1.4519],
[ 0.2338],
[-1.8889]])
>>> torch.hstack((cvtutorials2, cvtutorials2))
tensor([[-1.4519, -1.4519],
[ 0.2338, 0.2338],
[-1.8889, -1.8889]])