PyTorch库学习之torch.repeat_interleave函数

PyTorch库学习之torch.repeat_interleave函数

一、简介

torch.repeat_interleave 是 PyTorch 库中的一个函数,它用于重复张量中的元素。这个函数可以沿着指定的维度重复张量中的每个元素,返回一个新的张量。当不指定维度时,会将输入张量展平,并重复每个元素。这个函数在处理序列数据或生成数据增强样本时非常有用。

二、语法和参数

语法:

torch.repeat_interleave(input, repeats, dim=None) → Tensor

参数:

  • input (torch.Tensor): 输入张量。
  • repeats (int 或 torch.Tensor): 每个元素的重复次数。如果 repeats 是一个整数,则所有元素都将重复相同的次数;如果是一个张量,则需要与 input 张量的形状相同,并且会被广播以适应输入张量的维度。
  • dim (int, 可选): 重复操作的维度。如果不指定 (None),则默认将整个张量视为一维。

返回值:

  • 返回一个新的张量,其形状与输入张量相同,但沿给定维度 dim 的大小会根据重复次数进行调整。

三、实例

3.1 重复一维张量中的每个元素
import torch
x = torch.tensor([1, 2, 3])
result = torch.repeat_interleave(x, 2)

print(result.shape)
print(result)

输出:

torch.Size([6])
tensor([1, 1, 2, 2, 3, 3])
3.2 沿着指定维度重复二维张量的元素
import torch

y = torch.tensor([[1, 2], [3, 4]])
result = torch.repeat_interleave(y, 3, dim=1)

print(result.shape)
print(result)

输出:

torch.Size([2, 6])
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
3.3 使用不同重复次数重复二维张量的行
import torch

y = torch.tensor([[1, 2], [3, 4]])
repeats_per_row = torch.tensor([2, 3])
result = torch.repeat_interleave(y, repeats_per_row, dim=0)

print(result.shape)
print(result)

输出:

torch.Size([5, 2])
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [3, 4]])

四、注意事项

  • 如果 repeats 是一个张量,它必须是一维的,并且其长度必须与 input 张量在 dim 维度上的大小相同 。
  • dim 参数未指定时,repeats 必须是一个整数,不能是一个数组 。
  • 返回的张量与输入张量在除了 dim 维度以外的其他维度上具有相同的形状 。

你可能感兴趣的:(#,torch,pytorch,学习,人工智能)