首先看官方文档的解释:
功能:尝试将张量拆分为指定数量的数据块,每个数据块都是输入张量的一个视图。
例子:
>>> torch.arange(5).chunk(3)
(tensor([0, 1]),
tensor([2, 3]),
tensor([4]))
>>> torch.arange(6).chunk(3)
(tensor([0, 1]),
tensor([2, 3]),
tensor([4, 5]))
>>> torch.arange(7).chunk(3)
(tensor([0, 1, 2]),
tensor([3, 4, 5]),
tensor([6]))
通过例子很好理解,torch.chunk是将tensor切割为k个tensor。
首先看官方文档的解释:
功能:将张量分成数据块,每个数据块都是原始张量的视图。
官方解释的功能与torch.chunk大致相同,具体的区别通过例子很容易理解。
例子:
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,2])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
如果split_size_or_sections传入的是整数的话,就是沿着dim方向尽量分割出长度为split_size_or_sections的tensor。
如果split_size_or_sections传入的是list的话,就按照list的参数分割不同长度的tensor。
注:chunk是指定分割数量,split是指定分割完的tensor的样式。
首先看官方文档的解释:
功能:在给定维度上对输入的张量序列seq
进行连接操作,所有的张量必须有相同的形状(连接维度除外)或为空。
torch.cat()可以被看作是torch.split()和torch.chunk()的逆向操作。
例子:
>>> a = torch.arange(6).reshape(2, 3)
>>> a
tensor([[0, 1, 2],
[3, 4, 5]])
>>> torch.cat((a, a, a), 0)
tensor([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])
>>> torch.cat((a, a, a), 1)
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5, 3, 4, 5]])
通过例子很好理解,torch.cat将tensor沿着dim进行拼接。
功能:沿着一个新的维度连接一系列张量。所有张量需要是相同的大小。
例子:
>>> a = torch.arange(6).reshape(2, 3)
>>> a
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> torch.stack((a, a), 0) # torch.Size([2, 3, 3])
tensor([[[0, 1, 2],
[0, 1, 2]],
[[3, 4, 5],
[3, 4, 5]],
[[6, 7, 8],
[6, 7, 8]]])
>>> torch.stack((a, a), 1) # torch.Size([3, 2, 3])
tensor([[[0, 1, 2],
[0, 1, 2]],
[[3, 4, 5],
[3, 4, 5]]])
>>> torch.stack((a, a), 2) # torch.Size([3, 3, 2])
tensor([[[0, 0],
[1, 1],
[2, 2]],
[[3, 3],
[4, 4],
[5, 5]],
[[6, 6],
[7, 7],
[8, 8]]])
与cat相比不同的是,stack是在添加一个新的维度去连接,连接方式通过例子中的output和Size很容易理解。