2020-10-20

Torch.repeat函数理解

一、官方定义

官方文档定义如下:
repeat(*sizes) → Tensor
Repeats this tensor along the specified dimensions.

Unlike expand(), this function copies the tensor’s data.

WARNING

repeat() behaves differently from numpy.repeat, but is more similar to numpy.tile. For the operator similar to numpy.repeat, see torch.repeat_interleave().

Parameters
sizes (torch.Size or int…) – The number of times to repeat this tensor along each dimension

Example:

x = torch.tensor([1, 2, 3])
x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]])

x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])

二、理解

官方文档中提到与expand不同的是,expand函数仅会改变tensor的视图,而repeat会拷贝原tensor的数据。
从几个简单的例子能帮助更好的理解。

import torch
In [1]: import torch                                                                                                                                                                                                                                                 

In [2]: a = torch.randint(5,(2,3))                                                                                                                                                                                                                                   

In [3]: a.repeat(1,1)                                                                                                                                                                                                                                                
Out[3]: 
tensor([[2, 1, 3],
        [4, 3, 1]])#在各自维度复制一个,则保持不变。
In [4]: a.repeat(1,1,1)                                                                                                                                                                                                                                              
Out[4]: 
tensor([[[2, 1, 3],
         [4, 3, 1]]])#加了一个维度。
In [5]: a.repeat(1,2,3)                                                                                                                                                                                                                                              
Out[5]: 
tensor([[[2, 1, 3, 2, 1, 3, 2, 1, 3],
         [4, 3, 1, 4, 3, 1, 4, 3, 1],
         [2, 1, 3, 2, 1, 3, 2, 1, 3],
         [4, 3, 1, 4, 3, 1, 4, 3, 1]]])
         #在1维度复制1,2维度复制2,3维度3个       

最后一个理解可以这样,最后输出的张量第一个维度为1保持。
第二个维度需要复制两个,第三个维度复制三个。按照numpy.tile来理解就是在维度上堆砌相同的张量,从最后一个维度来看 横向堆砌三个张量,原张量本来应该为(2,3)所以现在应该是(2,9),横向堆砌两个,所以维度变为(4,9),最终输出为(1,4,9)。
再举一个例子帮助理解。

In [8]: a.repeat(4,2,1)                                                                                                                                                                                                                                              
Out[8]: 
tensor([[[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]],

        [[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]],

        [[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]],

        [[2, 1, 3],
         [4, 3, 1],
         [2, 1, 3],
         [4, 3, 1]]])#输出形状为(4,4,3)

你可能感兴趣的:(pytorch,深度学习)