[pytorch]如何将label转化成onehot编码

之前用octave学习神经网络的时候,用逻辑回归,激活函数是sigmoid,损失函数是交叉熵损失函数,那个时候不用任何框架,需要把label转化成onehot编码:

c =[1:10]
y =(y==c)

只需要两行代码,很简单。
现在使用pytorch框架,刚开始学,情况比较复杂,废了半天时间才能把自己的数据正确导入程序(需要用固定的torch容器来装),之后训练神经网路的时候开始使用交叉熵损失函数(CrossEntropyLoss),没有发现错误,改用MSE损失函数后反而会报错。后来知道,使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码,转化方法如下(https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/3):

[pytorch]如何将label转化成onehot编码_第1张图片
Paste_Image.png

你可能感兴趣的:([pytorch]如何将label转化成onehot编码)