Understanding the scatter_add_ function in PyTorch

Few blog posts explain this function in detail, so here is a short introduction based on my own understanding.
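Function: self_tensor.scatter_add_(dim, index_tensor, other_tensor) → Tensor (PyTorch's own documentation calls the last two arguments index and src)
Meaning: the values of other_tensor are added into self_tensor at the positions given by index_tensor. The operation is in-place: self_tensor itself is modified and then returned.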
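A minimal 1-D sketch of this description, using made-up labels: counting how often each value occurs. Here labels plays the role of index_tensor, a tensor of ones plays other_tensor, and counts plays self_tensor; these variable names and data are our own, chosen only for illustration.

```python
import torch

# Hypothetical data: class labels of six samples (values must lie in [0, 4]).
labels = torch.tensor([0, 2, 2, 4, 1, 2])
counts = torch.zeros(5, dtype=torch.long)  # self_tensor: one slot per class
ones = torch.ones_like(labels)             # other_tensor: add 1 per sample

# counts[labels[i]] += ones[i] for every i; repeated labels accumulate.
counts.scatter_add_(0, labels, ones)
print(counts)  # tensor([1, 1, 3, 0, 1])
```

Because repeated indices accumulate instead of overwriting each other, the result matches torch.bincount(labels, minlength=5).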
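Parameters:
dim: the dimension along which index_tensor is applied. Note that dim=1 does not mean the data of self_tensor along dim=0 cannot change. dim only says that, when the destination element is located, the coordinate along dimension 1 is read from index_tensor instead of being the element's own position; every other coordinate is kept as is. If this still sounds abstract, the examples below make it concrete.
self_tensor: the tensor we want to modify;
index_tensor: the index tensor;
other_tensor: the tensor whose values are added into self_tensor.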
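To make the dim description concrete, here is a small sketch with dim=1, reusing the same 2x2 data as the later examples in this article (the choice of dim=1 is ours; the article's own examples use dim=0).

```python
import torch

index_tensor = torch.tensor([[0, 1], [1, 1]])
self_tensor = torch.arange(0, 4).view(2, 2)    # [[0, 1], [2, 3]]
other_tensor = torch.arange(5, 9).view(2, 2)   # [[5, 6], [7, 8]]

# dim == 1: self[i][index[i][j]] += other[i][j]
# The row i comes from the element's own position; only the column is
# read from index_tensor, so both rows of self_tensor can still change.
self_tensor.scatter_add_(1, index_tensor, other_tensor)
print(self_tensor)
# tensor([[ 5,  7],
#         [ 2, 18]])
```

Row 0 receives 5 and 6 in columns 0 and 1; row 1 receives both 7 and 8 in column 1 (3 + 7 + 8 = 18), while self_tensor[1][0] = 2 is never touched.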
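Requirements:
1. self_tensor, index_tensor and other_tensor must have the same number of dimensions, i.e. self_tensor.dim() == index_tensor.dim() == other_tensor.dim();
2. with dim=d, every value in index_tensor must lie between 0 and self_tensor.size(d) - 1 inclusive;
3. index_tensor must be no larger than other_tensor in every dimension, and no larger than self_tensor in every dimension except d, i.e. index_tensor.size(k) <= other_tensor.size(k) for all k, and index_tensor.size(k) <= self_tensor.size(k) for all k != d. Along d itself, index_tensor may be longer than self_tensor, because several entries can be added into the same slot.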
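The three rules can be checked before calling the function. The sketch below uses a hypothetical helper, check_scatter_add_args (not part of PyTorch), written from the rules above; it also shows that index_tensor may have more entries than self_tensor along dim.

```python
import torch

def check_scatter_add_args(self_t, dim, index, src):
    """Hypothetical helper: raise ValueError if the shape/value rules
    for Tensor.scatter_add_ are violated; return None otherwise."""
    # Rule 1: same number of dimensions.
    if not (self_t.dim() == index.dim() == src.dim()):
        raise ValueError("self, index and src must have the same number of dimensions")
    # Rule 3: index fits inside src everywhere, and inside self except along dim.
    for d in range(index.dim()):
        if index.size(d) > src.size(d):
            raise ValueError(f"index.size({d}) > src.size({d})")
        if d != dim and index.size(d) > self_t.size(d):
            raise ValueError(f"index.size({d}) > self.size({d})")
    # Rule 2: every index value is a valid position along dim.
    if index.numel() > 0 and (index.min() < 0 or index.max() >= self_t.size(dim)):
        raise ValueError(f"index values must lie in [0, {self_t.size(dim) - 1}]")

# index has 4 rows but self_t has only 2 rows along dim=0: this is allowed.
self_t = torch.zeros(2, 3, dtype=torch.long)
index = torch.tensor([[0, 1, 1], [1, 0, 0], [1, 1, 0], [0, 0, 0]])
src = torch.ones(4, 3, dtype=torch.long)
check_scatter_add_args(self_t, 0, index, src)
self_t.scatter_add_(0, index, src)
print(self_t)  # tensor([[2, 2, 3], [2, 2, 1]])
```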

  • Formulas for 3-D tensors (i, j, k run over every position of index_tensor):

    self[index[i][j][k]][j][k] += other[i][j][k] # if dim == 0

    self[i][index[i][j][k]][k] += other[i][j][k] # if dim == 1

    self[i][j][index[i][j][k]] += other[i][j][k] # if dim == 2

  • Formulas for 2-D tensors (i, j run over every position of index_tensor):

    self[index[i][j]][j] += other[i][j] # if dim == 0

    self[i][index[i][j]] += other[i][j] # if dim == 1
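The 2-D and 3-D formulas above share one pattern: visit every position of index, copy that position, and replace only its dim-th coordinate with the index value. The sketch below is our own loop-based reference for any number of dimensions (scatter_add_reference is a hypothetical name, and it is far slower than the built-in); it is checked against scatter_add_ on random 3-D tensors.

```python
import itertools
import torch

def scatter_add_reference(self_t, dim, index, src):
    """Hypothetical loop-based version of the formulas; returns a new tensor."""
    out = self_t.clone()
    # Visit every position (i, j, k, ...) of index, not of self or src.
    for pos in itertools.product(*(range(s) for s in index.shape)):
        target = list(pos)
        target[dim] = int(index[pos])  # replace only the dim-th coordinate
        out[tuple(target)] += src[pos]
    return out

torch.manual_seed(0)
self_t = torch.zeros(3, 4, 5)
src = torch.randn(3, 4, 5)
for dim in range(3):
    # Random indices in [0, self_t.size(dim) - 1], same shape as src.
    index = torch.randint(0, self_t.size(dim), src.shape)
    expected = self_t.clone().scatter_add_(dim, index, src)
    assert torch.allclose(scatter_add_reference(self_t, dim, index, src), expected)
print("reference matches scatter_add_ for dim = 0, 1, 2")
```

allclose is used instead of exact equality because the built-in may sum colliding float values in a different order.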

Manual computation with explicit loops, following the 2-D formulas:

    import torch

    index_tensor = torch.tensor([[0, 1], [1, 1]])
    print('index_tensor: \n', index_tensor)
    self_tensor = torch.arange(0, 4).view(2, 2)
    print('self_tensor: \n', self_tensor)
    other_tensor = torch.arange(5, 9).view(2, 2)
    print('other_tensor: \n', other_tensor)
    dim = 0
    for i in range(index_tensor.size(0)):
        for j in range(index_tensor.size(1)):
            replace_index = int(index_tensor[i][j])
            if dim == 0:
                # the index replaces self's dimension-0 coordinate
                self_tensor[replace_index, j] += other_tensor[i, j]
            elif dim == 1:
                # the index replaces self's dimension-1 coordinate
                self_tensor[i, replace_index] += other_tensor[i, j]
    print(self_tensor)

Output:

    index_tensor:
     tensor([[0, 1],
            [1, 1]])
    self_tensor:
     tensor([[0, 1],
            [2, 3]])
    other_tensor:
     tensor([[5, 6],
            [7, 8]])
    tensor([[ 5,  1],
            [ 9, 17]])
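For dim = 0 the loop result can also be written without loops: self[r][j] receives the sum of every other[i][j] whose index[i][j] equals r. The sketch below expresses that sum with torch.nn.functional.one_hot; it assumes, as in this example, that index, src and self all have the same 2-D shape, and it is an illustration rather than how scatter_add_ is implemented.

```python
import torch
import torch.nn.functional as F

index = torch.tensor([[0, 1], [1, 1]])
self_t = torch.arange(0, 4).view(2, 2)
src = torch.arange(5, 9).view(2, 2)

# dim == 0: self[r][j] += sum over i of [index[i][j] == r] * src[i][j]
# hits has shape (I, J, R); hits[i, j, r] == 1 exactly where index[i][j] == r.
hits = F.one_hot(index, num_classes=self_t.size(0))
delta = (hits * src.unsqueeze(-1)).sum(dim=0).T  # sum over i, then shape (R, J)
result = self_t + delta
print(result)
# tensor([[ 5,  1],
#         [ 9, 17]])
```

The explicit sum over i shows why self[1][1] becomes 3 + 6 + 8 = 17: two entries of index point at the same slot and both are added.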

Computing the same result with the built-in function:

    import torch

    index_tensor = torch.tensor([[0, 1], [1, 1]])
    print('index_tensor: \n', index_tensor)
    self_tensor = torch.arange(0, 4).view(2, 2)
    print('self_tensor: \n', self_tensor)
    other_tensor = torch.arange(5, 9).view(2, 2)
    print('other_tensor: \n', other_tensor)
    self_tensor.scatter_add_(0, index_tensor, other_tensor)
    print(self_tensor)

Output:

    index_tensor:
     tensor([[0, 1],
            [1, 1]])
    self_tensor:
     tensor([[0, 1],
            [2, 3]])
    other_tensor:
     tensor([[5, 6],
            [7, 8]])
    tensor([[ 5,  1],
            [ 9, 17]])
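The trailing underscore marks scatter_add_ as in-place, which is why self_tensor is overwritten in the example above. When the original values are still needed, PyTorch also offers out-of-place forms, torch.scatter_add(input, dim, index, src) and the method Tensor.scatter_add(dim, index, src); a short sketch on the same data:

```python
import torch

index_tensor = torch.tensor([[0, 1], [1, 1]])
self_tensor = torch.arange(0, 4).view(2, 2)
other_tensor = torch.arange(5, 9).view(2, 2)

# Out-of-place forms: both return a new tensor and leave self_tensor unchanged.
result_fn = torch.scatter_add(self_tensor, 0, index_tensor, other_tensor)
result_method = self_tensor.scatter_add(0, index_tensor, other_tensor)

print(result_fn)    # same values as the in-place call: [[5, 1], [9, 17]]
print(self_tensor)  # still [[0, 1], [2, 3]]
```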
