本文借从label获得one hot编码的例子来说明scatter的使用。
首先我们来看看scatter的函数声明。其函数声明如下
tensor.scatter_(dim, index, src)
scatter函数涉及4个对象,分别是tensor,dim,index,src。该函数的主要作用是将src中的值根据dim和index填充到tensor中。这里有两个关键部分,1)将src中的值填充到tensor中;2)dim和index指定的位置,一个是src中的位置,另一个是tensor中的位置。
下面以使用scatter获得label的one hot编码为例,讲解scatter的用法。
一般来说,我们有一个label列表 y = { y 1 , y 2 , ⋯ , y n } y=\{y_1, y_2, \cdots,y_n\} y={y1,y2,⋯,yn}。我们想要label的one hot编码形式 y ^ = [ [ 0 , 0 , 0 , 1 , 0 , 0 , ⋯ , 0 ] , ⋯ ] \hat{y}=[[0,0,0,1,0,0,\cdots,0], \cdots] y^=[[0,0,0,1,0,0,⋯,0],⋯]。那么如何使用scatter将 y y y变成 y ^ \hat{y} y^?
下面是scatter的具体使用方法。
y_hat = torch.zeors(B, class_num).scatter_(dim=1, index=y.view(-1, 1), src=1.)
其中y_hat
是 y ^ \hat{y} y^,y
是 y y y。
在这个例子中,是将src中的1填充到tensor中。具体过程见下面的图。
总结一下这个操作过程。该过程中,根据index和dim,将src中的值全部填充到tensor中。
下面详细说明index和dim这两个参数是如何指定相应位置。
首先是dim,这个参数是与tensor有关,即沿着tensor的那个维度(轴)来展开。
而index既与tensor有关,也与src有关。index中的值和dim两者确定了tensor中的位置。而index的值所在的位置则指明了src中的位置。简单来说,index和src两者的位置一 一对应。
src = [
[0.4319, 0.6500, 0.4080, 0.8760, 0.2355],
[0.2609, 0.4711, 0.8486, 0.8573, 0.1029]
]
tensor = [
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
]
dim=0
index = [
[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]
]
result = tensor.scatter_(dim, index, src)
result = [
[0.4319, 0.4711, 0.8486, 0.8760, 0.2355],
[0.0000, 0.6500, 0.0000, 0.8573, 0.0000],
[0.2609, 0.0000, 0.4080, 0.0000, 0.1029]
]
具体过程见图,图中以index中的第一行为例。
从上图的操作过程可以看出,index和src两者的形状需要一致。
index根据dim展开即可得到src中的值在tensor中的位置。
index展开后的结果是
[
[(0,0), (1,1), (2,2), (0,3), (0,4)],
[(2,0), (0,1), (0,2), (1,3), (2,4)]
]
上面就是使用scatter的全部内容。