无脑入门pytorch系列(四)—— scatter_

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo
  • one-hot

官方定义

torch.tensor.scatter_是PyTorch中的一个函数,用于将指定索引处的值替换为给定的值。

函数定义:

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

官方解释:

  • 将张量src中的所有值写入索引张量中指定的index处的self。

  • 对于src中的每个值,它的输出索引由其在src中的索引(dimension != dim)和在index中对应的值(dimension = dim)指定。

非常难以理解,十分抽象,从我个人的角度来说就是:

  • 第一个参数dim表示维度,即在第几维度处理数据,保持其它维度不变。
  • reduce参数是一个可选参数,用于指定如何在执行散射(scatter)操作时对重复的索引值进行合并或聚合。
  • index则是需要填充的列的索引,即根据维度从src中取对应的值填充到tensor中去。

怎么映射的,比如一个一个3维张量:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

官方的文档如下,TORCH.TENSOR.SCATTER_:

无脑入门pytorch系列(四)—— scatter__第1张图片

即使如此理解起来也是很复杂,下面从例子中去理解:

demo

下面是一个官方文档给出的例子:

import torch

src = torch.Tensor([[-1.0276,  0.2673, -1.1752, -0.8823],
        [-0.6447, -0.8256,  0.1542, -0.4242]])
print(src)

output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

output = output.scatter(1, index, src)
print(output)

输出的结果:

无脑入门pytorch系列(四)—— scatter__第2张图片

我们一步步理解代码:

  1. 首先,定义了一个src张量,后续output即从src中取值。
  2. 其次,定义了output,其值为二行五列的全零张量,后续对output进行修改。
  3. 接着,定义了index,即从src取值的索引。
  4. 最后,根据index从src取值填充到output中,即完成操作。

那么具体是如何取值的呢?

首先,dim = 1,意味着从维度值为1的地方取值,维度值为0的地方不变,那就是:

self[i][index[i][j]] = src[i][j]  # if dim == 1

具体来说:

i = 0, j = 0时,output[0][index[0][0]] = src[0][0],因为index[0][0] = 3,所以output[0][3] = src[0][0] = -1.0276,这时候我们检查输出的output值,确实是-1.0276

同理:

i = 0, j = 1: output[0][index[0][1]] = output[0][1] = src[0][1] = 0.2673

i = 0, j = 2: output[0][index[0][2]] = output[0][2] = src[0][2] = -1.1752

one-hot

作者在学习该函数时实在遇到one-hot编码时遇到的,而该函数在one-hot中应用很广:

index = torch.tensor([[3], [2], [0], [1]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

无脑入门pytorch系列(四)—— scatter__第3张图片

你可能感兴趣的:(python,#,无脑入门pytorch系列,pytorch,人工智能,python)