def get_one_hot(labels, nb_classes):
res = np.eye(nb_classes)[np.array(labels).reshape(-1)]
return res.reshape(list(labels.shape)+[nb_classes])
解释:
np.array(labels).reshape(-1)
是将labels展平, 比如将[[2,1],[3,2],[0,0]]
(shape为[3,2])展平为[2,1,3,2,0,0]
(shape为[6, ]
)res = np.eye(nb_classes)[np.array(labels).reshape(-1)]
根据展平后的结果, 取np.eyes()中的对应行, 得到新的矩阵