解决RuntimeError: _thnn_mse_loss_forward is not implemented for type torch.cuda.LongTensor和scatter_方法

在PyTorch中遇到了如标题的问题,我使用的MSE损失函数,网上大多数给的是类型不匹配问题,在stackoverflow找到了问题的答案,这里出现的问题是因为loss需要one-hot类型的数据,而我们使用的是类别标签。

什么是one-hot?

一个例子解释什么是one-hot,对于5分类问题,我们使用[0,0,1,0,0]来表示这个实例是属于第三个类别的,等价于类别标签[2](从0对类别编码)。关于one-hot的好处,自行百度或google。

解决办法和scatter_函数介绍

我们需要将神经网络的预测out和实例本身的标签label变为one-hot形式,因为out和label里面存储的是最大值索引,所以变换依赖于Tensor对象的scatter_方法,在索引位置设为1,其他为0,关于方法介绍查看下面的链接
Pytorch 学习(5):Pytorch中的 torch.gather/scatter_ 聚集/分散操作
官方文档
关于如何使用该方法实现one-hot,见如下链接
转换为one-hot格式

在这里插入图片描述
scatter方法就是将src中的值按照index张量中的索引赋给当前张量相应位置的值,这里需要注意一件事情是scatter_方法要求
can be either empty or the same size of src. When empty, the operation returns identity,即index数组必须与src数组的维度一致,经过实验它还必须与源数组一样,原数组为2维,那么下标数组必须为2维(你可以这样想,如果源为二维,index数组为一维,假设dim为1,那么我们就不能确定那些行需要变换。其中每一维的数量可以不相等,但有相应约束,可看官方文档)。其中src可以为某一个浮点数,我估计它内部是进行了广播机制的,将这个浮点数扩展为输出数组的维度。综上,我们需要将上面的out和label数组变为二维(因为one-hot是二维的),可以调用out=out.reshape(种类数,1)。(PS:在pytorch中,一维是只输入一个int型就行了,只要输入了两个int数,那么一定是二维的,所以“(种类数,1)”得到二维张量)
在torch中,一个函数后面加上_符号,表示对自己作用,不加表示返回值为作用结果,而自身不改变。
这样就解决这个问题了。

你可能感兴趣的:(PyTorch)