根据torch的官方文档,torch和numpy中tile和repeat的对应关系如下
torch | numpy | |
---|---|---|
torch.Tensor.repeat | <=> | np.tile |
torch.Tensor.repeat_interleave | <=> | np.repeat |
关于这个操作的具体解释如下
torch.Tensor.repeat(*sizes) → Tensor
这个函数用来沿着多个维度重复原来的Tensor。
*代表传入的sizes不止一个参数,而是多个参数组成的tuple。如果size=(4,2), 则x.repeat(4,2)代表在dim0重复4次,在dim1重复2次。
举例:
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])
一般来讲,我们看这种用于repeat的形状(4,2),可以从后往前倒过来看,就是说tensor[1,2,3]现在横向上重复2次,再在列向上重复4次。这样就容易知道repeat得到的tensor的形状大致如何。
同理,x.repeat(4, 2, 1)就是再三个维度上分别重复1,2,4次。
numpy.tile(A, reps)
这个函数同样用来沿着多个维度重复原来的Tensor。如果reps=(4,2), 则np.tile(tensor,(4,2))代表tensor在aixs0上重复4次,在axis1上重复2次。
tile是瓦片、铺瓦的意思,意思是向瓦片一样从各个维度按规定的次数重复铺开来。其与torch.Tensor.repeat()用法一样。两者不同的是repeat封装为Tensor的函数,而tile是直接封装为numpy中的函数,再将tensorA传入进去。
例子:
>>>np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],
[0, 1, 2, 0, 1, 2]])
interleave:交错
torch.Tensor.repeat_interleave(repeats, dim=None) → Tensor
这个函数主要用于重复指定的一个维度,如果dim=None,该函数会将传入的tensor先扁平化为维度为1的数组,每个元素重复repeats次之后,返回扁平化的输出数组。
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
[3, 4],
[3, 4]])
numpy.repeat(a, repeats, axis=None)
这个函数主要用于重复指定的一个维度,如果axis=None,该函数会将传入的tensor先扁平化为维度为1的数组,每个元素重复repeats次之后,返回扁平化的输出数组。
>>>x = np.array([[1,2],[3,4]])
>>>np.repeat(x, 2)
array([1, 1, 2, 2, 3, 3, 4, 4])
>>>np.repeat(x, 3, axis=1)
array([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
>>>np.repeat(x, [1, 2], axis=0)
array([[1, 2],
[3, 4],
[3, 4]])
【1】https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html#torch.repeat_interleave
【2】https://numpy.org/doc/stable/reference/generated/numpy.repeat.html?highlight=repeat#numpy.repeat