在多分类任务中,通常将目标转换成独热编码来进行训练,本文将介绍
使用 scatter_()
来转换:
如使用独热进行编码:
label = torch.LongTensor([[1], [5], [7], [2], [9]])
'''
tensor([[1.],
[5.],
[7.],
[2.],
[9.]])
'''
one_hot_label = torch.zeros(5, 10).scatter_(1, label, 1)
'''
tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
'''
通过 argmax()
得到先前的向量。
如:
results = one_hot_label.argmax(dim=1, keepdim=True)
'''
tensor([[1],
[5],
[7],
[2],
[9]])
'''
于是就成功复原了
import tensorflow as tf
label = tf.stack(5)
one_hot_label = tf.one_hot(label, 10)
sess = tf.Session()
print("label: ", sess.run(label))
print("one_hot_label: ", sess.run(one_hot_label))
# 输出
label: 5
one_hot_label: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]