Pytorch scatter_ 理解

scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。

>>> x = torch.rand(2, 5)
>>> x
 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

index的shape刚好与x的shape对应,也就是index中每个元素指定x中一个数据的填充位置。dim=0,表示按行填充,主要理解按行填充。举例index中的第0行第2列的值为2,表示在第2行(从0开始)进行填充,对应到input = zeros(3, 5)中就是位置(2,2)。所以此处要求input的列数要与x列数相同,而index中的最大值应与zeros(3, 5)行数相一致。

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

同上理,可以把1.23看成[[1.23], [1.23]]。此处按列填充,index中的2对应zeros(2, 4)的(0,2)位置。

>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
 0.0000  0.0000  1.2300  0.0000
 0.0000  0.0000  0.0000  1.2300
[torch.FloatTensor of size 2x4]

综上,几点要注意:

  • index的shape要与填充数据src的shape一致,如果不一致,将进行广播

  • index中的索引指的是要把src中对应位置的数据按照指定那个维度(即dim)填充到原数据input中,我们知道了要填充的数据是什么,填充到input的哪行那列呢,dim指定哪个维度,这个维度就是index索引值,另一个维度就是这个索引在index中的位置。

  • scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会。PyTorch 中,一般函数加下划线代表直接在原来的 Tensor 上修改

  • scatter() 一般可以用来对标签进行 one-hot 编码,这就是一个典型的用标量来修改张量的一个例子

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
#        [0],
#        [3],
#        [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
#        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
#        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

转载于https://blog.csdn.net/qq_16234613/article/details/79827006
转载于https://www.cnblogs.com/dogecheng/p/11938009.html

你可能感兴趣的:(Pytorch scatter_ 理解)