torch: Invalid index in scatter at c:\a\w\1\s\windows\pytorch\aten\src\th\gene 解决

  • 文档了解:https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_
scatter_(dim, index, src) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

  • 官网的例子就可以看出来,dim=0的时候,src中的元素对应到index的元素,放在某行,这个由index确定;dim=1的时候src中的值代表对应某列。

For a 3-D tensor, self is updated as:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
  • 我有一个a需要独热处理,其中a中的每一个元素都是一个属性矩阵9x9;所以呢,独热处理的是这个矩阵,
a = torch.from_numpy(a).to('cpu').long() 
print(a.shape)

torch.Size([16, 9, 9])

独热的函数如下:

    def label2onehot(labels, dim):
        """Convert label indices to one-hot vectors."""
        # list(labels.size()) 返回labels的shape
        out = torch.zeros(list(labels.size())+[dim]).to('cpu')
        print(out.shape)
        # labels.unsqueeze(-1) 在最后增加一个维度;torch.Size([10, 9, 9]) to torch.Size([10, 9, 9, 1])
        print(labels.unsqueeze(-1).shape)
        print(len(out.size())-1)
        out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
        return out

当使用

a_tensor = label2onehot(a,4)

报错为:RuntimeError: Invalid index in scatter at c:\a\w\1\s\windows\pytorch\aten\src\th\generic/THTensorEvenMoreMath.cpp:549。

后来发现a中有的值为4,所以独热的时候最后一维应该是5才对,这样就正确了:

a_tensor = label2onehot(a,5)
a_tensor.shape

torch.Size([16, 9, 9, 5])

为什么需要5维,看这个:
https://zhuanlan.zhihu.com/p/35287916

你可能感兴趣的:(bugs,pytorch)