torch.Tensor.repeat是什么操作?对比numpy库和torch库中的tile和repeat操作

一、前言

根据torch的官方文档,torch和numpy中tile和repeat的对应关系如下

torch numpy
torch.Tensor.repeat <=> np.tile
torch.Tensor.repeat_interleave <=> np.repeat

关于这个操作的具体解释如下

二 torch.Tensor.repeat()和np.tile()

torch.Tensor.repeat(*sizes) → Tensor

这个函数用来沿着多个维度重复原来的Tensor。
*代表传入的sizes不止一个参数,而是多个参数组成的tuple。如果size=(4,2), 则x.repeat(4,2)代表在dim0重复4次,在dim1重复2次。

举例:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])

一般来讲,我们看这种用于repeat的形状(4,2),可以从后往前倒过来看,就是说tensor[1,2,3]现在横向上重复2次,再在列向上重复4次。这样就容易知道repeat得到的tensor的形状大致如何。

同理,x.repeat(4, 2, 1)就是再三个维度上分别重复1,2,4次。

numpy.tile(A, reps)

这个函数同样用来沿着多个维度重复原来的Tensor。如果reps=(4,2), 则np.tile(tensor,(4,2))代表tensor在aixs0上重复4次,在axis1上重复2次。
tile是瓦片、铺瓦的意思,意思是向瓦片一样从各个维度按规定的次数重复铺开来。其与torch.Tensor.repeat()用法一样。两者不同的是repeat封装为Tensor的函数,而tile是直接封装为numpy中的函数,再将tensorA传入进去。

例子:

>>>np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],
       [0, 1, 2, 0, 1, 2]])

三、torch.Tensor.repeat_interleave和np.repeat

interleave:交错

torch.Tensor.repeat_interleave(repeats, dim=None) → Tensor

这个函数主要用于重复指定的一个维度,如果dim=None,该函数会将传入的tensor先扁平化为维度为1的数组,每个元素重复repeats次之后,返回扁平化的输出数组。

>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])

numpy.repeat(a, repeats, axis=None)

这个函数主要用于重复指定的一个维度,如果axis=None,该函数会将传入的tensor先扁平化为维度为1的数组,每个元素重复repeats次之后,返回扁平化的输出数组。

>>>x = np.array([[1,2],[3,4]])
>>>np.repeat(x, 2)
array([1, 1, 2, 2, 3, 3, 4, 4])
>>>np.repeat(x, 3, axis=1)
array([[1, 1, 1, 2, 2, 2],
       [3, 3, 3, 4, 4, 4]])
>>>np.repeat(x, [1, 2], axis=0)
array([[1, 2],
       [3, 4],
       [3, 4]])

四、参考资料

【1】https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave
【2】https://numpy.org/doc/stable/reference/generated/numpy.repeat.html?highlight=repeat#numpy.repeat

你可能感兴趣的:(pytorch,numpy,深度学习,人工智能)