pytorch scatter_的用法及含义

scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中
这个函数可以从转换成onehot编码来理解。
看下面代码:

index = torch.tensor([1,2,1,2,0])
torch.zeros(5,3).scatter_(1, index.unsqueeze(1), 1)  
# tensor([[0., 1., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [1., 0., 0.]])

简要说明:这段代码的目的就是将李表[1,2,1,2,0]转成one-hot编码的形式。

因为有5个数据,然后数值范围从0~2,所以需要设置3列,所以目标矩阵应该是5x3。

scatter_中第一个1表示沿着维度1的方向也就是列的方向,第二个参数表示需要填值的索引,第三个参数表示填的值。

比如以输出结果的第一行为例,原本torch.zeros使得第一行的元素都是0,但是scatter_的第二个输入参数第一行是1(因为经过unsqueeze后第一行只有一个元素了),所以输出结果的第一行的第1个元素(从0开始)应该填上scatter最后一个参数所表示的值,剩下的以此类推。

不过其实得到onehot编码可以用pandas.get_dummies

index = [1,2,1,2,0]
pd.get_dummies(index)

你可能感兴趣的:(深度学习,pytorch)