PyTorch 中的 repeat() 函数

假设 x = tensor([0, 4, 8])y = tensor([0, 4, 8, 12]) ,则:

1. tensor.repeat(a0)
xx = x.repeat(len(y))

的结果是:

xx = tensor([0, 4, 8, 0, 4, 8, 0, 4, 8, 0, 4, 8])

定义向右是 “行的方向”,向下是 “列的方向”,则 tensor.repeat(a0) 可以看做是在行的方向上将原 tensor 复制了 a0 次,并且复制是在原 tensor 内部复制的(即复制出来的内容和原内容是在一个中括号里的)。

2. tensor.repeat(a1, a0)
yy = y.view(-1, 1).repeat(1, len(x))

的结果是:

yy = tensor([[ 0,  0,  0],
             [ 4,  4,  4],
             [ 8,  8,  8],
             [12, 12, 12]])

这里要注意复制的先后顺序,先在行的方向上复制 a0 次,再在列的方向上复制 a1 次。

3. tensor.repeat(an, ..., a1, a0)

先从 repeat() 函数最右边的参数 a0 开始复制,然后依次往左,复制 a1, a2, …, 直至最终到达 an —— 即先从最内层开始复制,最后复制最外层。

你可能感兴趣的:(PyTorch)