repeat()和expand()函数详解

torch.repeat()

  • 定义: repeat() 方法对张量的元素沿着指定的维度进行重复。

  • 参数:

    • *sizes (torch.Size 或 int...):一系列的整数,定义了每个维度上重复的次数。
  • 返回值: Tensor。一个新的张量,是原始张量沿着各个维度重复后的结果。

  • 用途: 使用repeat()方法可以创建重复元素的新张量,用于各种批处理或数据增强操作。

  • 代码示例:

    x = torch.tensor([1, 2, 3])
    x.repeat(4, 2)
    # 输出: tensor([[1, 2, 3, 1, 2, 3],
    #              [1, 2, 3, 1, 2, 3],
    #              [1, 2, 3, 1, 2, 3],
    #              [1, 2, 3, 1, 2, 3]])
    

    orch.expand()

  • 定义: expand() 方法返回一个新的视图,它将张量的大小扩展到更大的尺寸。

  • 参数:

    • *sizes (torch.Size 或 int...):扩展后的张量尺寸。
  • 返回值: Tensor。一个新的视图,它在不复制数据的情况下呈现了更大尺寸的张量。

  • 用途: expand() 方法常用于将一个小尺寸张量扩展为更大尺寸以进行广播操作,特别是在矩阵运算或批处理中。

  • 代码示例:

    x = torch.tensor([[1], [2], [3]])
    x.expand(-1, 3)
    # 输出: tensor([[1, 1, 1],
    #              [2, 2, 2],
    #              [3, 3, 3]])
    

你可能感兴趣的:(PyTorch,深度学习,python,人工智能)