Pytorch的repeat()方法再深度学习中经常用到,用于复制tensor,最好的说明当然是官方文档。
repeat的用法说明很简单:重复每个张量的维度的次数。
-这里有个warrning很有意思,意思是Pytorch的repeat和numpy.repeat是不太一样的。下次填坑。
import torch
x = torch.tensor([1, 2, 3])
print(x.shape)
# torch.Size([3])
print(x.repeat(4, 2))
"""
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])"""
x.repeat(4, 2, 1).size()
#torch.Size([4, 2, 3])
x是一维的tensor,但传入repeat的size是二维的即(4,2)维度时不对应的,看一下复制流程,x是一维tensor,但是可以看成是二维的,新增的维度的值为1。举个例子相当于把一个n维向量(行向量)看作一个一行n列的矩阵,向量是一维但矩阵是二维的。
import torch
# 原始x是一维的张量
x = torch.tensor([1, 2, 3])
# 把x的维数增加一维变成二维
x = x.reshape(1,3)
x.repeat(4,2)
# 得到相同的结果
"""
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])"""
repeat(4,2)相当于把整个tensor在行方向上复制4次,在列方向上复制2次。注意是整个tensor,而不是复制完一行接着复制下一行。
# x.shape (2,3)
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 维度对应时在相应的维度复制即可
x.repeat(4, 2)
"""
tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6],
[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]])
"""
再看官方例子的最后一行代码:
import torch
# 此时x是一维的
x = torch.tensor([1, 2, 3])
# 复制的是三维的
x.repeat(4, 2, 1)
# 和上面例子是一样的,先把x升到3维
x = torch.tensor([1, 2, 3])
# 把x看成是一个1通道1行3列的三维张量
x = x.reshape(1,1,3)
x.repeat(4, 2, 1)
# 对应维度复制即可得到结果
# x变成了4通道2行3列的张量
"""
tensor([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])
"""