pytorch中scatter_()函数用法

scatter_(dimindexsrc) → Tensor

首先看一下这个函数的接口,需要三个输入:1)维度dim; 2)索引数组index; 3)原数组src,为了方便理解,我们后面把src换成input表示。最终的输出是新的output数组。即 scatter_(dim, index, input) → Tensor

下面依次介绍:

1)维度dim:整数,可以是0,1,2,3...

2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置

3)原数组input:也是一个tensor,其中的数据类型任意

先说一下这个函数是干sa的,在我看来,这个scatter_函数就是把input数组中的数据进行重新分配。index中的数表示了要把原数组中的数据分配到output数组中的位置,如果未指定,则填充为0。

注意scatter_函数是inplace操作。

比如说下面这段代码:

import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

运行结果如下:

pytorch中scatter_()函数用法_第1张图片

下面,详细说一下为什么会是这样的结果。

前面说了,scatter是input数组根据index数组对input数组中的数据进行重新分配,我们看一下分配过程是怎样的。

input:

       tensor([[-0.9817, -2.3192,  0.1756,  1.2975],
                    [ 0.8049,  0.4067, -0.0477, -0.3837]])

index:

        torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

output:

        tensor([[ 1.2975, -2.3192,  0.1756, -0.9817,  0.0000],
                    [-0.0477,  0.8049,  0.4067, -0.3837,  0.0000]])

首先,对input[0][0]进行重分配。

NOTE: 符号 -> 代表赋值。由于scatter方法的第一维dim=1,所以input数组中的数据只是在第1维上进行重新分配,第0维不变。以二维数组举例,第一行的数据重新分配后一定在还是第一行,不能跑到第二行。

input[0][0] -> output[0][index[0][0]] = output[0][3]

数据位置发生的变化都是在第1维上,第0维不变。

input[0][1] -> output[0][index[0][1]] = output[0][1]

input[0][2] -> output[0][index[0][2]] = output[0][2]

input[0][3] -> output[0][index[0][3]] = output[0][0]

 

为了方便理解,我是按照input中数据的顺序索引的,但是在pytorch中,是根据从index[0][0]到index[0][3]这样的顺序去索引的,索引的input位置和output的位置必须要存在,否则会提示错误。但是,不一定所有的input数据都会分到output中,output也不是所有位置都有对应的input,当output中没有对应的input时,自动填充0
 

你可能感兴趣的:(pytorch)