nn.Embedding()的原理

nn.Embedding()的原理:

定义一个Embedding:

embeddings = nn.Embedding(num_embeddings=10, embedding_dim=3)

vocab_size : 10

输出维度为: 3

假定输入inputs如下:

inputs = torch.tensor([
    [1,3,6, 8],
    [9,1,3,5]
],dtype=torch.long)

max_num 为 9

vocab_size: 10

下面讲述他的原理:

首先假定inputs.shape = (batch_size , sentence_len)

即(2,4)

现在我们用

first = F.one_hot(inputs,num_classes=10)

去做第一次one_hot的输出,即shape = (2,4,10)

10 代表  v

即shape = (2,4 , v)

embeddings的weight.shape = (10,3)  =>>> (v,s)

那么怎么得到(2,4,s)呢?

torch.matmul(torch.tensor(first, dtype=torch.float),embeddings.weight)

即可得到embeddings(inputs)相同的结果!!!

下面为代码:

nn.Embedding()的原理_第1张图片

你可能感兴趣的:(python,人工智能,embedding,深度学习,人工智能)