【python】repeat_interleave()

功能:重复张量的元素 

torch.repeat_interleave(input, repeats, dim=None)

input :输入张量
repeats:每个元素的重复次数。
dim:需要重复的维度。默认输出(flatten)扁平化张量

 一维:

a = torch.tensor([1, 2, 3, 4])

a.repeat_interleave(2) # 等同于:torch.repeat_interleave(a, 2)

#结果:tensor([1, 1, 2, 2, 3, 3, 4, 4])

二维:

b = torch.tensor([[1, 2], [3, 4]])
b.repeat_interleave(2) # 等同于:torch.repeat_interleave(b, 2)

# 结果:tensor([1, 1, 2, 2, 3, 3, 4, 4])

不同元素重复不同次数:

b = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(b, torch.tensor([1, 2]), 0)
# 结果:tensor([[1, 2],
               [3, 4],
               [3, 4]])

torch.repeat_interleave(b, torch.tensor([1, 2]), 1)
# 结果:tensor([[1, 2, 2],
               [3, 4, 4]])

你可能感兴趣的:(Python,python,开发语言,后端)