【Pytorch】repeat()和expand()区别

torch.Tensor是包含一种数据类型元素的多维矩阵。

torch.Tensor is a multi-dimensional matrix containing elements of a single data type.

torch.Tensor有两个实例方法可以用来扩展某维的数据的尺寸,分别是repeat()expand()

expand()

expand(*sizes) -> Tensor
*sizes(torch.Size or int) - the desired expanded  size
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.

返回当前张量在某维扩展更大后的张量。

例子:

import torch
x = torch.tensor([1, 2, 3])
y = x.expand(2, 3)
print(y)
结果为:
tensor([[1, 2, 3],
        [1, 2, 3]])

print(x)
结果为:
tensor([1, 2, 3])

就拿上面的例子来说 x的形状为3,他被扩展成2行3列,该扩展形状即为最终形状。

那么此时x会自动在高位添加1这个空维度,这时x会变为1*3的形状,随后使用复制的方式,将形状变为2*3。

>> x = torch.randn(2, 1, 1, 4)
>> x.expand(-1, 2, 3, -1)
torch.Size([2, 2, 3, 4])

正如上面的这个例子,不难看出,都是沿着1所在的维度进行复制。 

repeat()

repeat(*sizes) -> Tensor
*size(torch.Size or int) - The  number of times to repeat this tensor along each dimension.
Repeats this tensor along the specified dimensions.

沿着特定的维度重复这个张量。

例子:

import torch

>> x = torch.tensor([1, 2, 3])
>> x.repeat(3, 2)
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])

上面的x形状为3,由于repeat后面的维度是2个维度,因此x也需要变成2个维度,即为1*3。

接下来repeat即从x最右边的3这个维度开始,x的3所在的维度被重复了2次,此时x变成[1,2,3,1,2,3];

然后看x的1这个维度,被重复了3次,变成[[1,2,3,1,2,3],[1,2,3,1,2,3][1,2,3,1,2,3]]。

>> x2 = torch.randn(2, 3, 4)
>> x2.repeat(2, 1, 3).shape
torch.Tensor([4, 3, 12])

假设x2的矩阵为[[[a1,a2,a3,a4],[a1,a2,a3,a4],[a1,a2,a3,a4]],[[a1,a2,a3,a4],[a1,a2,a3,a4],[a1,a2,a3,a4]]]

上面x2的形状为2*3*4,由于repeat的维度为3,所以x2的维度中不需要补1,可以直接进行重复操作。

首先从x2的最后边的4这个维度开始,被重复了3次,变为

[[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]]];

然后看x2的3这个维度,被重复了1次,其矩阵没有变化,依旧为

[[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]]];

最后看x2的2这个维度,被重复了2次,这时变为

[[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]],

[[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4],[a1,a2,a3,a4,a1,a2,a3,a4,a1,a2,a3,a4]]];

转自:https://zhuanlan.zhihu.com/p/58109107

你可能感兴趣的:(Pytorch)