torch.scatter

本文目录

    • 一、函数简介
    • 二、二维举例
    • 三、详解执行过程
      • 1. 第一步
      • 2. 第二步
      • 3. 第三步
      • 4. 问题

一、函数简介

torch.scatter(input, dim, index, src)

  • dim ([int]) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
  • src ([Tensor] or [float] – the source element(s) to scatter.
  • reduce ([str], optional) – reduction operation to apply, can be either 'add' or 'multiply'.

将src中的数据根据index中的索引按照dim的方向填入到input中

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

看了上述的官方文档还是不理解,我们继续看官方的例子,这里官方只给了三维,我在这里又加入了二维,在这之前有一个规定

  • 对任意维度d:index.size(d) <= src.size(d)
  • 对d!=dim的维度:index.size(d) <= self.size(d)

二、二维举例

self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1

先上代码

torch.manual_seed(0)
x = torch.arange(0, 12).reshape(2, 6)
x= x.type(torch.float32)
print(x)
'''
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.]])
'''


index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index)
'''
tensor([[0, 1, 2, 0, 0],
        [2, 0, 0, 1, 2]])
'''


y = torch.zeros(3, 6)
y = torch.scatter(y, 0, index, x)
print(y)
'''
tensor([[ 0.,  7.,  8.,  3.,  4.,  0.],
        [ 0.,  1.,  0.,  9.,  0.,  0.],
        [ 6.,  0.,  2.,  0., 10.,  0.]])
'''


'''in-place operation'''
yy = torch.zeros(3, 6)
yy.scatter_(0, index, x)
print(yy)
'''
tensor([[ 0.,  7.,  8.,  3.,  4.,  0.],
        [ 0.,  1.,  0.,  9.,  0.,  0.],
        [ 6.,  0.,  2.,  0., 10.,  0.]])
'''

下面将上述的执行过程绘制出来

三、详解执行过程

1. 第一步

首先下面是我们的初始化,我们初始化了srcindexinput并且设置了dim=0
填充公式为self [ index[i][j] ][j] = src[i][j]

因为dim=0,所以需要填充的input的行的索引就由index数值也就是index[i][j]来确定,需要填充的input的列的索引对应于index的列,往self里面填充的具体数值由index对应的src来确定

看下面例子序号3

  • 需要填充的input的行的索引:行=index[i][j]=index[0][2]=2
  • 需要填充的input的列的索引:列=index列=j=2
  • self填充的具体数值:self[行][列]=self[2][2]=src[i][j]=2
torch.scatter_第1张图片

2. 第二步

下面我们继续上述步骤

可发现当我们进行到第六步的时候,index[0][5]并不存在,所以直接跳过就可以了

torch.scatter_第2张图片

3. 第三步

在这一步我们将input填充完毕

如图所示,这里我们取图中的序号11进行验证

序号11

  • 需要填充的input的行的索引:行=index[i][j]=index[1][4]=2
  • 需要填充的input的列的索引:列=index列=j=4
  • self填充的具体数值:self[行][列]=self[2][4]=src[i][j]=10
  • 所以在self的第二行,第四列填入10
torch.scatter_第3张图片

4. 问题

为什么在第二步我们遇到的问题吗:当我们进行到序号6的时候,index[0][5]并不存在,我们选择了跳过

可以跳过而没有报错呢,因为最初的文档对src, index, self的维度有过定义

  • 对任意维度d:index.size(d) <= src.size(d)
  • 对d!=dim的维度:index.size(d) <= self.size(d)

所以index的维度是可以小于src的维度的,关系如下torch.scatter_第4张图片

你可能感兴趣的:(PyTorch,pytorch,深度学习,python)