# PyTorch实现one-hot embedding (PyTorch implementation of one-hot embedding)

import torch


def one_hot_embedding(labels, num_classes):
    """Encode integer class labels as one-hot vectors.

    Creates a ``num_classes x num_classes`` identity matrix and selects the
    rows indexed by ``labels``. Row ``i`` of the identity matrix is exactly
    the one-hot vector for class ``i``.

    Args:
      labels: (int or LongTensor) class labels, typically sized [N,].
        A Python int yields a single vector sized [#classes]. A tensor of
        shape [...] yields a tensor of shape [..., #classes]. Negative
        labels are interpreted Python-style (counted from the end) by the
        indexing operation.
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N,#classes], with torch's default
      floating dtype. The result lives on the same device as ``labels``
      when ``labels`` is a tensor, and on the CPU otherwise.
    """
    # Build the identity matrix next to the labels: a CPU matrix cannot be
    # indexed with CUDA indices, so a hard-coded 'cpu' device would break
    # GPU inputs.
    device = labels.device if torch.is_tensor(labels) else 'cpu'
    y = torch.eye(num_classes, device=device)  # [D,D]
    return y[labels]  # [N,D]
if __name__ == '__main__':
    # Small demo: encode class index 2 out of 10 possible classes.
    demo_label, demo_num_classes = 2, 10
    print(one_hot_embedding(demo_label, demo_num_classes))
    # Expected output:
    # tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])

# 你可能感兴趣的 (you may also be interested in): pytorch