PyTorch内Tensor按索引赋值的方法比较

有很多时候,我们需要对深度学习过程中的tensor进行一些非整齐、离散化的赋值操作,例如我们让网络的一支输出可能的索引值,而另外一支可能需要去取对应索引值的内容。PyTorch提供了几种方法实现上述操作,但是其实际效果之间存在差异,在这里整理一下。

  1. scatter_(dim, index, src)
    按照index,将src的数据散放到self的'dim'维度中。例如,对于三维Tensor,效果如下:
    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
    
    • dim (int) - 要散布拷贝的维度
    • index (LongTensor) - 散布拷贝的索引
    • src (Tensor or float) - 要散布拷贝的源,可以是单个浮点值或是tensor
  2. index_fill_(dim, index, val)
    按照index,将val的值填充selfdim维度。效果如下:
    >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
    >>> index = torch.tensor([0, 2])
    >>> x.index_fill_(1, index, -1)
    tensor([[-1.,  2., -1.],
            [-1.,  5., -1.],
            [-1.,  8., -1.]])
    
    • dim (int) - 要填充的维度
    • index (LongTensor) - 要填充的索引
    • val (float) - 要填充的值
  3. index_put_(indices, value)
    按照indices,将val的值填充到self的对应位置。效果如下:
    >>> a = torch.zeros([5,5])
    >>> index = (torch.LongTensor([0,1]),torch.LongTensor([1,2])
    >>> a.index_put_(index), torch.Tensor([1,1]))
    tensor([[ 0.,  1.,  0.,  0.,  0.],
            [ 0.,  0.,  1.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.],
            [ 0.,  0.,  0.,  0.,  0.]])
    
    • indices (tuple of LongTensor) - 要填充的索引
    • value (Tensor) - 要填充的值组成的tensor

这三者的参数名相像,但实际上对各参数的定义有差别,要仔细跟据参数类型和例子好好分析。

你可能感兴趣的:(PyTorch内Tensor按索引赋值的方法比较)