torch.repeat_interleave
是 PyTorch 库中的一个函数,它用于重复张量中的元素。这个函数可以沿着指定的维度重复张量中的每个元素,返回一个新的张量。当不指定维度时,会将输入张量展平,并重复每个元素。这个函数在处理序列数据或生成数据增强样本时非常有用。
语法:
torch.repeat_interleave(input, repeats, dim=None) → Tensor
参数:
input
(torch.Tensor): 输入张量。repeats
(int 或 torch.Tensor): 每个元素的重复次数。如果 repeats
是一个整数,则所有元素都将重复相同的次数;如果是一个张量,则需要与 input
张量的形状相同,并且会被广播以适应输入张量的维度。dim
(int, 可选): 重复操作的维度。如果不指定 (None
),则默认将整个张量视为一维。返回值:
dim
的大小会根据重复次数进行调整。import torch
x = torch.tensor([1, 2, 3])
result = torch.repeat_interleave(x, 2)
print(result.shape)
print(result)
输出:
torch.Size([6])
tensor([1, 1, 2, 2, 3, 3])
import torch
y = torch.tensor([[1, 2], [3, 4]])
result = torch.repeat_interleave(y, 3, dim=1)
print(result.shape)
print(result)
输出:
torch.Size([2, 6])
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
import torch
y = torch.tensor([[1, 2], [3, 4]])
repeats_per_row = torch.tensor([2, 3])
result = torch.repeat_interleave(y, repeats_per_row, dim=0)
print(result.shape)
print(result)
输出:
torch.Size([5, 2])
tensor([[1, 2],
[1, 2],
[3, 4],
[3, 4],
[3, 4]])
repeats
是一个张量,它必须是一维的,并且其长度必须与 input
张量在 dim
维度上的大小相同 。dim
参数未指定时,repeats
必须是一个整数,不能是一个数组 。dim
维度以外的其他维度上具有相同的形状 。