Python 生成one_hot标签和恢复

    首先生成一个主对角全为1的其余全为0的矩阵,比如有n个分类就是n * n,效果如下:

Python 生成one_hot标签和恢复_第1张图片

    随后根据标签列表(或者numpy数组)选取合适的行,比如标签是[9, 1, 0, 0], 那么就会选择上图矩阵中对应的9、1、0、0行,得到one_hot标签,如果不熟悉numpy数组的列表切片的(就是说numpy_array[slice]中的slice是列表) ,可以看下这篇Python Numpy数组使用列表索引

    恢复的话就是找列表中为1的下标即可。

    代码如下:

# encoding = utf-8
'''
    author : James-J
    time : 2019/05/29
'''

import numpy as np

if __name__ == '__main__':
    one_hot = np.eye(10) # 10*10的矩阵 对角线上是1
    print('np.eye(10)\n', one_hot)
    # 两种方法 传一维的numpy数组和列表都可以
    label = np.array([1, 4, 8, 9, 5, 0])
    one_hot_label = one_hot[label.astype(np.int32)] # 表示选取矩阵上面的第几行
    # label = [1, 4, 8, 9, 5, 0]
    # one_hot_label = one_hot[label]
    print('-----------------one_hot--------------------')
    print(one_hot_label)

    label = [one_label.tolist().index(1) for one_label in one_hot_label] # 找到下标是1的位置
    print('------------------label---------------------')
    print(label)

    得到的结果:

Python 生成one_hot标签和恢复_第2张图片

 

 

你可能感兴趣的:(Python)