pytorch one-hot 小技巧

首先介绍一下np.eye()

numpy.eye(N,M=None,k=0,dtype=,order='C)

参数 类型 Value
N int 表示的是输出的行数
M int型 输出的列数,默认为N
k int型 对角线的下标,默认为0表示的是主对角线,负数表示的是低对角,正数表示的是高对角。
dtype $1 数据的类型,可选项,返回的数据的数据类型
order {‘C’,‘F’} 可选项,也就是输出的数组的形式是按照C语言的行优先’C’,还是按照Fortran形式的列优先‘F’存储在内存中

例子:

import numpy as np

labels = np.array([[1], [2], [0], [1]])
print("labels的大小:", labels.shape, "\n")

# 因为我们的类别是从0-2,所以这里是3个类
a = np.eye(3)[1]
print("如果对应的类别号是1,那么转成one-hot的形式", a, "\n")

a = np.eye(3)[2]
print("如果对应的类别号是2,那么转成one-hot的形式", a, "\n")

a = np.eye(3)[1, 0]
print("1转成one-hot的数组的第一个数字是:", a, "\n")

# 这里和上面的结果的区别,注意!!!
a = np.eye(3)[[1, 2, 0, 1]]
print("如果对应的类别号是1,2,0,1,那么转成one-hot的形式\n", a)

res = np.eye(3)[labels.reshape(-1)]
print("labels转成one-hot形式的结果:\n", res, "\n")
print("labels转化成one-hot后的大小:", res.shape)

pytorch one-hot 小技巧_第1张图片

你可能感兴趣的:(pytorch,numpy,python)