【Pytorch | Tensorflow】--- label与one-hot独热编码向量之间的相互转换

在多分类任务中,通常将目标转换成独热编码来进行训练,本文将介绍

一. Pytorch操作

1.1. label → \rightarrow one-hot向量

使用 scatter_() 来转换:

如使用独热进行编码:

label = torch.LongTensor([[1], [5], [7], [2], [9]])
'''
tensor([[1.],
        [5.],
        [7.],
        [2.],
        [9.]])
'''

one_hot_label = torch.zeros(5, 10).scatter_(1, label, 1)
'''
tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
'''

1.2. one-hot编码 → \rightarrow label

通过 argmax() 得到先前的向量。

如:

results = one_hot_label.argmax(dim=1, keepdim=True)
'''
tensor([[1],
        [5],
        [7],
        [2],
        [9]])
'''

于是就成功复原了


二. tensorflow操作

2.1. label → \rightarrow one-hot向量


import tensorflow as tf

label = tf.stack(5)
one_hot_label = tf.one_hot(label, 10)

sess = tf.Session()
print("label: ", sess.run(label))
print("one_hot_label: ", sess.run(one_hot_label))
# 输出
label:  5
one_hot_label:  [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

你可能感兴趣的:(Python)