pytorch:scatter_函数生成one_hot向量

实不相瞒,这个函数我看了整整一个小时才弄明白,真是让人抓狂。
先来看一下该函数的参数。

scatter_(dim, index, src)
dim:维度,表示在第几维上操作;
index:索引,后面再解释;
src:用来填充的tensor。

这个函数的主要作用是按照一定规则用src中的值去填充/替换原来tensor的值。
举个例子:

src = torch.tensor([[1.000], [2.000]])
index = torch.tensor([[1], [2]])
output= torch.zeros(2, 4)
output.scatter_(1, index, src)
print(output)

首先,src和index的维度应该是一致的。由于dim=1,所以行数不变,只变化列。
index中的值:
在这里插入图片描述
变化的列值即本身。
现在我们一共得到两个位置,[0][1]和[1][2]。这个位置就是我们需要填充的output的位置。那么用什么值来填充呢。答案就是index原始位置对应的src的位置。
output[0][1]=src[0][0]=1.000
output[1][2]=src[1][0]=2.000
output其他位置的值不变。最终的输出如下:
在这里插入图片描述
明白了上面这个例子,我们再来看如何用scatter_函数生成one-hot向量。代码如下:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

dim=1,行不变,列变化。
1->原始位置[0][0]->变化列->[0][1]
2->原始位置[1][0]->变化列->[1][2]
0->原始位置[2][0]->变化列->[2][0]
3->原始位置[3][0]->变化列->[3][3]

output[0][1]=src[0][0]=1
output[1][2]=src[1][0]=1
output[2][0]=src[2][0]=1
output[3][3]=src[3][0]=1
这里需要注意一个问题,src只有一个数,根据广播机制,我们将1复制成4行1列再填充output中的相应位置。
结果如下:
在这里插入图片描述
以上。

你可能感兴趣的:(pytorch)