理解pytorch系列:整型索引是怎么实现的

整型索引的匹配规则

在PyTorch中使用整型索引时,需要遵循一些基本规则来确定如何从原始张量中选择数据。整型索引可以是Python中的列表或者数组、NumPy数组,或者是PyTorch的LongTensor。整型索引允许在任何维度上进行复杂的数据选取操作,例如选择特定的行、列或者任意的元素。

以下是PyTorch中整型索引的匹配规则:

  1. 单维度索引:如果你对一个维度使用整型索引(比如通过传递一个整数列表),你将根据列表中的每个整数值得到被索引维度上对应的切片。索引列表中的每个整数指定要选择的数据在该维度上的位置。

    示例:

    import torch
    
    x = torch.arange(12).view(3, 4)
    index = torch.tensor([0, 2])
    selected = x[index]  # 选取第一行和第三行
    
  2. 跨维度的整型索引:如果你使用多个整型索引列表分别对应于多个维度,你将在这些维度上得到一个表格化的选取。每组索引列表定义在对应维度上的位置,交叉点上的元素被选中。

    示例:

    rows = torch.tensor([0, 2])
    columns = torch.tensor([1, 3])
    selected = x[rows, columns]  # (0,1) 和 (2,3) 位置上的元素被选中
    
  3. 张量索引:你也可以使用一个整型张量作为索引。如果索引张量是一个压平(flatten)的一维张量,则选择的是按照索引张量指示的线性索引的元素。如果索引张量是多维的,则返回的张量形状将匹配索引张量的形状。

    示例:

    row_indices = torch.tensor([0, 1, 2])
    col_indices = torch.tensor([1, 1, 1])
    selected = x[row_indices, col_indices]  # 选择三个元素,它们处于不同的行但相同列
    
  4. 广播规则:整型索引同样也受到广播规则的影响。这意味着如果你在不同的维度上使用了不同长度的索引列表,PyTorch会尝试将它们广播到一个共同的形状,然后执行索引操作。

在使用整型索引时,返回的张量总是一个复制,而不是原始数据的视图。这意味着对返回的张量所做的修改不会影响原始张量。在执行整型索引操作时,维度的顺序是非常重要的,因为它们决定了哪些数据将会被选择。

上述内容是PyTorch整型索引的一些基础规则和用例,当然,PyTorch提供的索引能力还包括更高级和复杂的用法,如使用掩码张量或组合不同类型的索引。

整型索引的底层逻辑

PyTorch中的整型索引(也称为高级索引或花式索引)允许使用整数数组来选择数据。整型索引可以在多个维度上非连续地选择数据,并且索引数组不需要与被索引数组的形状相匹配。

对于整型索引的实现,当你提供整数数组或整数张量给PyTorch张量时,底层实现会在C++层处理索引操作。以下是大致的实现步骤:

  1. 分析索引指令:PyTorch检测你提供的索引,并将其识别为整型索引操作。

  2. 内存分配:基于索引操作,PyTorch会分配一个新的内存空间来存储索引后得到的张量。

  3. 数据拷贝:PyTorch会遍历索引张量中的每个元素,并且在原始张量中查找对应位置的元素,然后将这些找到的元素复制到步骤2中分配的内存空间中。

  4. 返回新张量:将拷贝填充后的内存空间包装成一个新的PyTorch张量对象,然后返回该张量。

整型索引操作是一个相对“昂贵”的操作,因为它通常涉及数据的复制而不是简单的视图或内存共享。这意呀着整型索引得到的结果通常都是一个新的张量,与原始张量不共享数据。

以下是在Python中使用PyTorch进行整型索引的简单示例:

import torch

# 创建一个2维张量
data = torch.tensor([[1, 2], [3, 4], [5, 6]])

# 使用整型索引
indices = torch.tensor([0, 2])
selected_data = data[indices]

print(selected_data)  # 输出张量:[[1, 2], [5, 6]]

在这个示例中,indices张量包含了想要选择的行的索引。在内部,selected_data的创建涉及到遍历indices数组,并且从data张量中抽取出相应的行,然后将这些行组合成一个新的张量。

实际上的C++实现比这个过程要复杂,因为PyTorch的内核会优化这些操作,可能会并行化遍历索引以加快速度。此外,它还需要处理各种边缘情况,并确保在并发环境下的安全性和效率。如果你想了解底层的实现细节,可以查看PyTorch的开源代码,特别是在其GitHub仓库中与张量索引相关的部分。

你可能感兴趣的:(自动驾驶,深度学习,经验分享,pytorch)