pytorch中torch.Tensor.scatter用法

函数 tensor.scatter_(dim, index, src)

返回值:返回一个根据index映射关系映射后的新的tensor 

参数解释:dim 变化的维度

                   index 映射关系

                   src 输入的tensor

代码示例:

import torch

x = torch.FloatTensor([[ 1, 2, 3, 4,5],

                                    [6, 7,8, 9,10]])

result = torch.zeros(3, 5)

indices = torch.tensor([[0, 1, 2, 0, 0], 

                                    [2, 0, 0, 1, 2]])

 result.scatter_(dim = 0, index = indices, src = x)

输出结果:

tensor([[ 1.,  7.,  8.,  4.,  5.],       

             [ 0.,  2.,  0.,  9.,  0.],       

             [ 6.,  0.,  3.,  0., 10.]])

代码解释:

    result.scatter_(dim = 0, index = indices, src = x)

    dim=0:

    那么转换过程中,只改src中各个元素的行,不改变列。从最后的结果可以看到:

    x中第一列 1和6,转换后输出结果中 1和6依然在第一列,只是行发生了变化。换言之,行根据index参数的映射关系进行了映射。那么怎么理解index的映射关系呢?

    indices = torch.tensor([[0, 1, 2, 0, 0], 

                                       [2, 0, 0, 1, 2]])

    既然列不变,不妨我们纵向来看:

    index参数中的 第0行第0列是0,那么就把src中的[0][0] =1,映射到 [0][0] = 1

    index参数中的 第1行第0列是2,那么就把src中的[1][0]  =6,映射到 [2][0]  = 6

    index第0列已经全部映射完毕,但是第0列还有一个[1][0]是空的,那就自动赋值为0

    于是,结果中的 第一列 分别为 1 0 6

那如果dim  = 1 会发生什么呢?先上代码

代码示例:

import torch

x = torch.FloatTensor([[ 1, 2, 3, 4,5],

                                    [6, 7,8, 9,10]])

result = torch.zeros(3, 5)

indices = torch.tensor([[0, 1, 2, 0, 0], 

                                    [2, 0, 0, 1, 2]])

result.scatter_(dim = 1, index = indices, src = x)

输出结果:

tensor([[ 5.,  2.,  3.,  0.,  0.],

           [ 8.,  9., 10.,  0.,  0.],

           [ 0.,  0.,  0.,  0.,  0.]])

再分析一次:

        dim  = 1,即映射关系是列映射关系。那么行不变。我们横向分析,以第0行为例:

       index参数中的 第0行第0列是0,那么就把src中的[0][0] =1,映射到 [0][0] = 1

       index参数中的 第0行第1列是1,那么就把src中的[0][1]  =2,映射到 [0][1]  =2

       index参数中的 第0行第2列是2,那么就把src中的[0][2] =3,映射到 [0][2] = 3

       index参数中的 第0行第3列是0,那么就把src中的[0][3]  =4,映射到 [0][0]  =4

       index参数中的 第0行第4列是0,那么就把src中的[0][4]  =5,映射到 [0][0]  =5

      index第0列已经全部映射完毕,未赋值的,自动赋0

      于是,结果中的 第1行结果为 5 2 3 0 0

那最后一行为什么都是0呢?

因为dim = 1,要求只改变第二维度,行维度不变,index中没有第2行的映射关系。也就是说,再映射过程中,源数据src的元素在哪一行,映射结果中就在哪一行。源数据中第二行没有元素,那映射结果中的第二行只能赋0。

注意:本文中下标都是0开始。

作者:LambAI https://www.bilibili.com/read/cv18787354?spm_id_from=333.999.list.card_article.click 出处:bilibili

你可能感兴趣的:(机器学习,pytorch,深度学习,python)