pytorch中的scatter讲解

本文借从label获得one hot编码的例子来说明scatter的使用。

scatter的初步认识

首先我们来看看scatter的函数声明。其函数声明如下

tensor.scatter_(dim, index, src)

scatter函数涉及4个对象,分别是tensor,dim,index,src。该函数的主要作用是将src中的值根据dim和index填充到tensor中。这里有两个关键部分,1)将src中的值填充到tensor中;2)dim和index指定的位置,一个是src中的位置,另一个是tensor中的位置。

一个使用scatter的例子:从label获得one hot编码

下面以使用scatter获得label的one hot编码为例,讲解scatter的用法。

一般来说,我们有一个label列表 y = { y 1 , y 2 , ⋯   , y n } y=\{y_1, y_2, \cdots,y_n\} y={y1,y2,,yn}。我们想要label的one hot编码形式 y ^ = [ [ 0 , 0 , 0 , 1 , 0 , 0 , ⋯   , 0 ] , ⋯   ] \hat{y}=[[0,0,0,1,0,0,\cdots,0], \cdots] y^=[[0,0,0,1,0,0,,0],]。那么如何使用scatter将 y y y变成 y ^ \hat{y} y^?

下面是scatter的具体使用方法。

y_hat = torch.zeors(B, class_num).scatter_(dim=1, index=y.view(-1, 1), src=1.)

其中y_hat y ^ \hat{y} y^y y y y

在这个例子中,是将src中的1填充到tensor中。具体过程见下面的图。

pytorch中的scatter讲解_第1张图片
总结一下这个操作过程。该过程中,根据index和dim,将src中的值全部填充到tensor中。

下面详细说明index和dim这两个参数是如何指定相应位置。
首先是dim,这个参数是与tensor有关,即沿着tensor的那个维度(轴)来展开。
而index既与tensor有关,也与src有关。index中的值和dim两者确定了tensor中的位置。而index的值所在的位置则指明了src中的位置。简单来说,index和src两者的位置一 一对应。

一个使用scatter的复杂例子:


src = [
		[0.4319, 0.6500, 0.4080, 0.8760, 0.2355],
		[0.2609, 0.4711, 0.8486, 0.8573, 0.1029]
	  ]


tensor = [
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
        ]

dim=0

index = [
            [0, 1, 2, 0, 0],
            [2, 0, 0, 1, 2]
        ]

result = tensor.scatter_(dim, index, src)

result = [
            [0.4319, 0.4711, 0.8486, 0.8760, 0.2355],
            [0.0000, 0.6500, 0.0000, 0.8573, 0.0000],
            [0.2609, 0.0000, 0.4080, 0.0000, 0.1029]
        ]

具体过程见图,图中以index中的第一行为例。

pytorch中的scatter讲解_第2张图片

从上图的操作过程可以看出,index和src两者的形状需要一致。

index根据dim展开即可得到src中的值在tensor中的位置。
index展开后的结果是

[
    [(0,0), (1,1), (2,2), (0,3), (0,4)],
    [(2,0), (0,1), (0,2), (1,3), (2,4)]
]

上面就是使用scatter的全部内容。

你可能感兴趣的:(pytorch,pytorch,python,人工智能)