pytorch中 nn.Embedding的原理

nn.Embedding 实质上是矩阵运算,实现维度转换

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
 max_norm=None,  norm_type=2.0,   scale_grad_by_freq=False, 
 sparse=False,  _weight=None)

num_embeddings : 输入数据的类别数
embedding_dim : 数据的编码维度

from torch import nn
z = nn.Embedding(3,2)
z.weight

输出

Parameter containing:
tensor([[ 0.5813, -0.4503],
        [-1.8539,  0.6905],
        [ 0.3107, -0.7194]], requires_grad=True)

y = z(torch.tensor([1,2,0]))

输出

y
tensor([[-1.8539,  0.6905],
        [ 0.3107, -0.7194],
        [ 0.5813, -0.4503]], grad_fn=)

由上面测试可见,该函数先将输入数据由索引编号转换为onehot编码的矩阵,然后右乘一个权重矩阵完成 输入的embedding,在训练过程中梯度更新权重,寻找让loss减小的embedding方式

你可能感兴趣的:(pytorch,pytorch,深度学习,神经网络)