torch.Tensor.index_add_函数,pytorch中的tf.unsorted_segment_sum

ref:

  • https://pytorch.org/docs/1.4.0/tensors.html?highlight=index_add_#torch.Tensor.index_add_
  • https://blog.csdn.net/weixin_44289071/article/details/103882658

torch.Tensor.index_add_能实现指定行或列的内容相加的功能,类似于tensorflow中tf.unsorted_segment_sum函数,可以用在比如实例分割中进行特征聚合的步骤。比如一个N*C的feature根据实例label可以将属于同一实例的点的特征聚合起来,得到Ins_num*C的聚合特征。

1. 函数的参数

torch.Tensor.index_add_函数,pytorch中的tf.unsorted_segment_sum_第1张图片

  • dim:这个参数表明你要沿着哪个维度索引;
  • index:包含索引的tensor;
  • tensor:被索引出来去相加的tensor;
  • 注意事项:x相加前后的shape保持不变,被索引的tensor在被索引的维度(第dim维)之外的维度上与tensor的对应维度必须保持一致,且index中的值最大不能超过x在被索引的维度上的最大维数,index的长度必须和tensor[dim]相同。假如x的shape(N, C),索引的维度为第0维(dim=0),那么被索引的tensor的dim=1的维度也必须为C,index的值必须介于0C-1之间,index的长度必须和被索引的tensor的dim=0的数字相同。

2. 使用示例

import torch
x = torch.ones(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=torch.float)
index = torch.tensor([0, 2, 4, 2])
new_x = x.index_add_(0, index, t)  # 把x的每一行加上index从t中索引出来的值
print('new_x: {}'.format(new_x))
new_x: tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [15., 17., 19.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.]])

解释一下:x.index_add_()表示在x的每一行加上index从t中索引出来的值,这个例子中x初始为一个5行3列全为1的tensor。

  • 如何确定x的第i行要加上的值,首先通过index[j]=i找到所有满足条件的j,再把所有的t[j]加上x[i]就得到新的x[i]。一行行来看。
  • new_x的第0行:首先去找index中值为0的的索引,找index[j]=0,所以j=0,新的new_x[0]=t[0]+x[0],即new_x[0]=[1, 2, 3]+[1, 1, 1]=[2, 3, 4]
  • new_x的第1行:首先去找index中值为1的的索引,找不到对应的j,所以没有东西可以加上,new_x[1]保持不变。
  • new_x的第2行:首先去找index中值为2的的索引,找index[j]=2,所以j=1, 3,新的new_x[2]=t[1]+t[3]+x[0],即new_x[2]=[4, 5, 6]+[10, 11, 12]+[1, 1, 1]==[15, 17, 19]
  • new_x的第3行:首先去找index中值为3的的索引,找不到对应的j,所以没有东西可以加上,new_x[3]保持不变。
  • new_x的第4行:首先去找index中值为4的的索引,找index[j]=4,所以j=2,新的new_x[4]=t[2]+x[4],即new_x[4]=[7, 8, 9]+[1, 1, 1]=[8, 9, 10]

3. 使用(我自己看的)

  • 可以根据自己的工程需要去分配每一个输入的值。
  • 假设我需要聚合属于相同实例的点的特征,我可以把初始的x设为shape为(Ins_num, C)的全0数组;t设置为需要被索引的feature,其shape为(N, C);索引就可以为实例label,shape为(N, )
  • label中为每个点所属的实例类别,其值为0Ins_num-1。(如果这个当中某些背景点的label为-100或者其他值就要注意了,索引会报错,记得处理一下)
  • 那么最终就可以得到一个shape为(Ins_num, C)的新tensor,每i行代表对应实例label为i的所有点相加后的特征。

你可能感兴趣的:(pytorch,tensorflow,深度学习,机器学习,pytorch,tensorflow)