pytorch-全面讲解函数topk, scatter, gather

这三个函数在pytorch中关于矩阵操作的非常实用的函数。我认为要想熟练的使用pytorch,能够灵活的使用这三个函数是至关重要的

文章目录

    • 三者的相同点:维度->数据的映射方式
    • topk
    • gather
    • scatter

三者的相同点:维度->数据的映射方式

因为三者都存在相似的地方,所以我这里放在一起来讲。这个共同点就是index -> value的方式:这里以官方给的gather函数对应为例:

# for a 3-D data
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

这样一看,并不好理解,举个例子:

  • 关于shape的变化:
    • 输入为[3, 4, 1]的数据x | 它的index为[3, 2, 1]
    • x.gather(dim=1, index)输出维度为[3, 2, 1]。它保持另外两维不变,仅在这一维上操作。
  • 关于数据的变化
    • idx中的数据代表在指定维度上的index。 pytorch-全面讲解函数topk, scatter, gather_第1张图片

topk

其实前面讲的映射方式计算起来还是容易乱,不过幸好并不影响我们的使用。emm实在不能理解可以忽略,只需要知道在指定维度上操作即可

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
  • 主要用途:依照大小,从矩阵某维度取值和取索引。常与scatter、gather连用。
  • 函数返回两个变量:value和index
  • 维度变化:假设指定维度为1,则(b, n, m)-> (b, k, m)
  • 其它用途:topk的数据默认按照从大到小排列,因此我们可以当做矩阵中的数据排序来用,若largest=False则为升序:

pytorch-全面讲解函数topk, scatter, gather_第2张图片

gather

torch.gather(input, dim, index, out=None) → Tensor
  • 用途:依照index来对矩阵进行取值
  • 函数返回与输入idx维度相同的tensor
  • 维度变化:假设指定dim=1,index=(b,k,n),input=(b,m,n)。则输出为(b,k,n)。在维度1上按照index进行取值。

scatter

torch.scatter(input, dim, index, src) → Tensor
  • 用途:与gather类似,不过它并不用来取值。scatter用来更替矩阵中指定index位置的值。

  • 维度变化:假设指定dim=1,index=(b,k,n),input=(b,m,n),src=(b,m,n)。则输出为(b,k,n)。在维度1上按照index从src取值,然后替换到input上相同的index位置。

  • 两种用法:

    • 一般要求source的维度为input维度相同,如下例:

    pytorch-全面讲解函数topk, scatter, gather_第3张图片

    • 当然,也可以直接指定要替换的值,如下:

    pytorch-全面讲解函数topk, scatter, gather_第4张图片

你可能感兴趣的:(pytorch,python,人工智能,pytorch)