tf.one_hot

tensorflow.one_hot

tf.one_hot(labels, depth, axis)

作用如下:

tf.one_hot_第1张图片

其中:label是y,depth是右边矩阵的深度(也就是有多少个类)

import numpy as np
import tensorflow as tf

def one_hot_matrix(lable,clas):
    one_hot_matrix = tf.one_hot(indices=lable , depth=clas , axis=0)
    with tf.Session() as sess:
        reslut = sess.run(one_hot_matrix)

    return reslut

y = np.array([1,2,3,4,5])
c = 6
print(one_hot_matrix(y,c))

 

你可能感兴趣的:(Tensorflow)