pytorch的python API略读--tensor(五)

作者:机器视觉全栈er
网站:cvtutorials.com

2.1.4 合

合和分是相反的两个过程。

torch.cat: cat是英文单词concatenate的缩写,表示连接的意思。torch.cat将一系列的tensor连接起来,这里要求被连接的tensor要有同样的shape,这个函数的用法如下:

torch.cat(tensors, dim=0, *, out=None)

这里的tensors的常见形式为(x1, x2, x3, …),其中每个都是一个tensor,dim表示多个tensor沿着哪个维度进行连接,默认是0,表示沿着第0维连接,我们以二阶tensor(矩阵)为例:

>>> cvtutorials = torch.randn(2, 2)
>>> cvtutorials
tensor([[ 1.0078,  0.6374],
        [-1.2574, -1.2475]])
>>> torch.cat((cvtutorials, cvtutorials), 0)
tensor([[ 1.0078,  0.6374],
        [-1.2574, -1.2475],
        [ 1.0078,  0.6374],
        [-1.2574, -1.2475]])
>>> torch.cat((cvtutorials, cvtutorials), 1)
tensor([[ 1.0078,  0.6374,  1.0078,  0.6374],
        [-1.2574, -1.2475, -1.2574, -1.2475]])

从中可以看出,沿着矩阵的第0维拼接,指的是增加矩阵的行数,沿着矩阵的第1维拼接,指的是增加矩阵的列数。

torch.column_stack: 将多个tensor沿着水平方向进行堆叠,即按列拼接,用法如下:

torch.column_stack(tensors, *, out=None)

举个简单的例子:

>>> a = torch.arange(1, 5)
>>> a
tensor([1, 2, 3, 4])
>>> cvtutorials = torch.arange(6, 10)
>>> cvtutorials
tensor([6, 7, 8, 9])
>>> torch.column_stack((a, cvtutorials))
tensor([[1, 6],
        [2, 7],
        [3, 8],
        [4, 9]])

torch.dstack, torch.vstack, torch.hstack:分别是按照深度、垂直和水平三个方向进行堆叠,用法如下:

torch.dstack(tensors, *, out=None)
torch.hstack(tensors, *, out=None)
torch.vstack(tensors, *, out=None)

对上面三个函数用些例子进行说明比较,如下:

>>> import torch
>>> cvtutorials1 = torch.randn(1, 3)
>>> cvtutorials2 = torch.randn(3, 1)
>>> cvtutorials1
tensor([[-1.8013,  2.4992, -0.2126]])
>>> cvtutorials2
tensor([[-1.4519],
        [ 0.2338],
        [-1.8889]])

>>> torch.dstack((cvtutorials1, cvtutorials1))
tensor([[[-1.8013, -1.8013],
         [ 2.4992,  2.4992],
         [-0.2126, -0.2126]]])
>>> torch.vstack((cvtutorials1, cvtutorials1))
tensor([[-1.8013,  2.4992, -0.2126],
        [-1.8013,  2.4992, -0.2126]])
>>> torch.hstack((cvtutorials1, cvtutorials1))
tensor([[-1.8013,  2.4992, -0.2126, -1.8013,  2.4992, -0.2126]])


>>> torch.dstack((cvtutorials2, cvtutorials2))
tensor([[[-1.4519, -1.4519]],
        [[ 0.2338,  0.2338]],
        [[-1.8889, -1.8889]]])
>>> torch.vstack((cvtutorials2, cvtutorials2))
tensor([[-1.4519],
        [ 0.2338],
        [-1.8889],
        [-1.4519],
        [ 0.2338],
        [-1.8889]])
>>> torch.hstack((cvtutorials2, cvtutorials2))
tensor([[-1.4519, -1.4519],
        [ 0.2338,  0.2338],
        [-1.8889, -1.8889]])

你可能感兴趣的:(pytorch的python,API,python,人工智能,计算机视觉)