PyTorch常用张量切割和拼接方法(torch.chunk、torch.split、torch.cat和torch.stack用法详解)

1、torch.chunk

首先看官方文档的解释:

功能:尝试将张量拆分为指定数量的数据块,每个数据块都是输入张量的一个视图。

例子:

>>> torch.arange(5).chunk(3)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4]))

>>> torch.arange(6).chunk(3)
(tensor([0, 1]),
 tensor([2, 3]),
 tensor([4, 5]))

>>> torch.arange(7).chunk(3)
(tensor([0, 1, 2]),
 tensor([3, 4, 5]),
 tensor([6]))

         通过例子很好理解,torch.chunk是将tensor切割为k个tensor。

2、torch.split

首先看官方文档的解释:

功能:将张量分成数据块,每个数据块都是原始张量的视图。

官方解释的功能与torch.chunk大致相同,具体的区别通过例子很容易理解。

例子:

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])

>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))

>>> torch.split(a, [1,2])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

        如果split_size_or_sections传入的是整数的话,就是沿着dim方向尽量分割出长度为split_size_or_sections的tensor。

        如果split_size_or_sections传入的是list的话,就按照list的参数分割不同长度的tensor。

注:chunk是指定分割数量,split是指定分割完的tensor的样式。

3、torch.cat

首先看官方文档的解释:

PyTorch常用张量切割和拼接方法(torch.chunk、torch.split、torch.cat和torch.stack用法详解)_第1张图片

功能:在给定维度上对输入的张量序列seq 进行连接操作,所有的张量必须有相同的形状(连接维度除外)或为空。

           torch.cat()可以被看作是torch.split()和torch.chunk()的逆向操作。

例子:

>>> a = torch.arange(6).reshape(2, 3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])

>>> torch.cat((a, a, a), 0)
tensor([[0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5]])

>>> torch.cat((a, a, a), 1)
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5, 3, 4, 5]])

        通过例子很好理解,torch.cat将tensor沿着dim进行拼接。

4、torch.stack

功能:沿着一个新的维度连接一系列张量。所有张量需要是相同的大小。

例子:

>>> a = torch.arange(6).reshape(2, 3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> torch.stack((a, a), 0) # torch.Size([2, 3, 3])
tensor([[[0, 1, 2],
         [0, 1, 2]],

        [[3, 4, 5],
         [3, 4, 5]],

        [[6, 7, 8],
         [6, 7, 8]]])

>>> torch.stack((a, a), 1) # torch.Size([3, 2, 3])
tensor([[[0, 1, 2],
         [0, 1, 2]],

        [[3, 4, 5],
         [3, 4, 5]]])

>>> torch.stack((a, a), 2) # torch.Size([3, 3, 2])
tensor([[[0, 0],
         [1, 1],
         [2, 2]],

        [[3, 3],
         [4, 4],
         [5, 5]],

        [[6, 6],
         [7, 7],
         [8, 8]]])

        与cat相比不同的是,stack是在添加一个新的维度去连接,连接方式通过例子中的output和Size很容易理解。 

你可能感兴趣的:(pytorch,深度学习,神经网络)