scatter_add() function

PyTorch Tensor usage

Official documentation: https://pytorch.org/docs/stable/tensors.html?highlight=scatter_add#torch.Tensor.scatter_add_

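Signature: scatter_add_(dim, indexTensor, otherTensor) → Tensor (recent versions of the official docs name the third argument src)

Usage: selfTensor.scatter_add_(dim, indexTensor, otherTensor). The trailing underscore marks the in-place version, which modifies selfTensor and also returns it; selfTensor.scatter_add(dim, indexTensor, otherTensor) without the underscore returns a new tensor and leaves selfTensor unchanged.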
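A minimal sketch of the three calling forms, assuming a recent PyTorch version. The tensors x, index and src are made-up illustration data; the index values are kept distinct so that each position of x receives at most one value.

```python
import torch

x = torch.zeros(5)
index = torch.tensor([0, 1, 3])      # target positions in x (dim 0)
src = torch.tensor([1.0, 2.0, 4.0])  # values to add at those positions

# Out-of-place method: returns a new tensor and leaves x unchanged.
y = x.scatter_add(0, index, src)

# Functional form, equivalent to the method call above.
z = torch.scatter_add(x, 0, index, src)

# In-place method: modifies x itself (and returns it).
x.scatter_add_(0, index, src)

print(y)  # tensor([1., 2., 0., 4., 0.])
print(x)  # tensor([1., 2., 0., 4., 0.])
```

All three produce the same values; only the in-place call changes x.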

Requirements:

  1. self, index and other should have the same number of dimensions.
  2. index.size(d) <= other.size(d) for all dimensions d
  3. index.size(d) <= self.size(d) for all dimensions d != dim.
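  4. As for gather(), the values of index must be between 0 and self.size(dim) - 1, inclusive.
  5. Unlike scatter_(), the values of index do not need to be unique: when several entries of other map to the same position of self, all of them are added there. (Older versions of the official docs also listed uniqueness as a requirement here, but it does not apply to scatter_add_().)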
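The sketch below (hypothetical data, recent PyTorch assumed) walks through the requirements on a 2-D example with dim = 0. index is smaller than src, so only the top-left part of src is read. It also shows that, for scatter_add(), repeated index values are accepted and summed, and that an out-of-range index value is rejected; the exception is caught broadly because its exact type is an assumption here.

```python
import torch

self_t = torch.zeros(3, 4)
src = torch.arange(1.0, 13.0).reshape(3, 4)  # values 1..12, shape (3, 4)

# index has shape (2, 3): not larger than src in any dimension, and not
# larger than self_t in dimension 1 (every dimension except dim = 0).
# All values lie in [0, self_t.size(0) - 1] = [0, 2].
index = torch.tensor([[0, 1, 2],
                      [0, 0, 2]])

# dim = 0: out[index[i][j]][j] += src[i][j]; only src[:2, :3] is read.
# out[0][0] receives 1 + 5 and out[2][2] receives 3 + 7: repeated
# indices in a column are summed, not rejected.
out = self_t.scatter_add(0, index, src)
print(out)  # values: [[6, 6, 0, 0], [0, 2, 0, 0], [0, 0, 10, 0]]

# An index value outside [0, self.size(dim) - 1] violates requirement 4.
bad_index = torch.tensor([[3, 0, 0]])
try:
    self_t.scatter_add(0, bad_index, src)
    raised = False
except (RuntimeError, IndexError):
    raised = True
print(raised)  # True
```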

Example code: final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
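The example line reads like the final-distribution step of a pointer-generator style summarization model. The sketch below reproduces it under stated assumptions: vocab_dist_ is taken to be the vocabulary distribution scaled by a generation probability p_gen and padded with zeros for in-article OOV words, attn_dist_ the attention distribution scaled by 1 - p_gen, and enc_batch_extend_vocab the extended-vocabulary id of each source token. p_gen, all sizes and all values are hypothetical.

```python
import torch

torch.manual_seed(0)
batch_size, vocab_size, src_len, n_oov = 2, 6, 4, 2  # hypothetical sizes

# Hypothetical per-example generation probability, shape (batch, 1).
p_gen = torch.tensor([[0.7], [0.4]])
vocab_dist = torch.softmax(torch.randn(batch_size, vocab_size), dim=1)
attn_dist = torch.softmax(torch.randn(batch_size, src_len), dim=1)

# Scale both distributions; pad the vocabulary side with zeros for the
# n_oov extra ids that exist only in the extended vocabulary (ids 6 and 7).
vocab_dist_ = torch.cat([p_gen * vocab_dist,
                         torch.zeros(batch_size, n_oov)], dim=1)
attn_dist_ = (1 - p_gen) * attn_dist

# Extended-vocabulary id of every source token, shape (batch, src_len).
# In the first example token id 2 occurs twice.
enc_batch_extend_vocab = torch.tensor([[2, 6, 2, 0],
                                       [7, 1, 3, 6]])

# dim = 1: final_dist[b][enc_batch_extend_vocab[b][j]] += attn_dist_[b][j]
final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
print(final_dist.shape)       # torch.Size([2, 8])
print(final_dist.sum(dim=1))  # each row sums to 1 (up to rounding)
```

With dim = 1 each example (row) is handled independently: the attention weight of source position j is added to column enc_batch_extend_vocab[b][j]. A token that appears twice in the source, such as id 2 in the first row, collects both attention weights, which is why the accumulating scatter_add() is used rather than scatter().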

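The function adds every value of otherTensor into selfTensor, at the positions given by indexTensor. For 3-D tensors:

self[ index[i][j][k] ][ j ][ k ] += other[ i ][ j ][ k ]  # if dim == 0
self[ i ][ index[i][j][k] ][ k ] += other[ i ][ j ][ k ]  # if dim == 1
self[ i ][ j ][ index[i][j][k] ] += other[ i ][ j ][ k ]  # if dim == 2

Here i, j and k range over the shape of index, so entries of other outside that shape are ignored.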
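To make the dim == 0 rule concrete, here is a naive triple-loop version for 3-D tensors, checked against PyTorch's own result. scatter_add_dim0_reference is a hypothetical helper written only for illustration; it is much slower than the built-in and is not how PyTorch implements the operation.

```python
import torch

def scatter_add_dim0_reference(self_t, index, other):
    """Naive loop version of self_t.scatter_add(0, index, other) for 3-D tensors."""
    out = self_t.clone()  # out-of-place: leave self_t untouched
    size_i, size_j, size_k = index.shape
    for i in range(size_i):
        for j in range(size_j):
            for k in range(size_k):
                # self[index[i][j][k]][j][k] += other[i][j][k]  (dim == 0)
                out[index[i, j, k].item(), j, k] += other[i, j, k]
    return out

torch.manual_seed(0)
self_t = torch.zeros(3, 2, 2)
other = torch.randn(4, 2, 2)
index = torch.randint(0, 3, (4, 2, 2))  # values in [0, self_t.size(0) - 1]

expected = self_t.scatter_add(0, index, other)
actual = scatter_add_dim0_reference(self_t, index, other)
print(torch.allclose(expected, actual))  # True
```

Because index.size(0) = 4 is larger than self_t.size(0) = 3, some rows of self_t necessarily receive several contributions, and the loop version sums them exactly as the built-in does.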
