pytorch分类任务时将标签转成tensor类型

pytorch分类任务时,需要将标签转成可用于学习的tensor类型。标签可分为两种,一种是常用于多分类的整型数字标签(类似0,1,2,3);另一种是one-hot类型,假设共有三个类别,那么1,2,3对应的one-hot 分别是0 0 0, 0 1 0, 0 0 1。
 

一、one-hot类型的标签直接转tensor类型

我们通常会将one-hot类型的标签以list的形式读取,然后直接就可以使用torch.LongTensor(label)进行转换,类型为torch.int64

label=[1,0,0]
label_tensor=torch.LongTensor(label)
print(label_tensor)
print(label_tensor.dtype)

#tensor([1, 0, 0])
#torch.int64

二、int类型的标签转成tensor类型

之前一直使用的one-hot类型的标签,后来使用整型的标签,直接使用torch.LongTensor(label)发现老是报错 stack expects each tensor to be equal size, but got XXX at entry X and YYY at entry Y意思就是标签的维度不一致,就很纳闷,标签就是一个数字,怎么就维度不一致了,后来查看了torch.LongTensor的官方代码,发现了问题所在。

torch.LongTensor()的传入参数可以是一个列表,也可以是一个维度值,当我把单个数字传入时,就生成了不同维度的一个列表,并且值也是随机的,如下所示。并不是我想象中的tensor([4.0]),而是一个四维的列表,值是随机的,害!

label=4
label_tensor=torch.LongTensor(label)
print(label_tensor)
print(label_tensor.dtype)

#tensor([140265829301552, 140265829301552,  837,  94241316263856])
#torch.int64

正确的方式,应该是先将单个python类型的数据转成numpy,然后再转成tensor类型,如下所示

label=4
label_numpy=np.array(label)
label_tensor=torch.from_numpy(label_numpy)
print(label_tensor)
print(label_tensor.dtype)

#tensor(4)
#torch.int64

你可能感兴趣的:(pytorch,分类,深度学习)