Pytorch中的repeat()函数

pytorch中的repeat()函数可以对张量进行复制。

当参数只有一个时,参数表示在列方向上复制的次数
当参数只有两个时,第一个参数表示在行方向上复制的次数,第二个参数表示在列方向上复制的次数。
当参数有三个时,第一个参数表示在通道数方向上复制的次数,第二个参数表示在行方向上复制的次数,第三个参数表示在列方向上复制的次数。

接下来我们举一个例子来直观理解一下:

>>> x = torch.tensor([6,7,8])
>>> x.repeat(4)
tensor([[6, 7, 8, 6, 7, 8, 6, 7, 8, 6, 7, 8]])


>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,2)
tensor([[6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8]])
>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,1)
tensor([[6, 7, 8],
        [6, 7, 8],
        [6, 7, 8],
        [6, 7, 8]]) 
      
>>> x.repeat(4,2,1)
tensor([[[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]]])
>>> x.repeat(4,2,1).size()
torch.Size([4, 2, 3])

你可能感兴趣的:(pytorch,CV,python,pytorch,深度学习,python)