The Transformer Learning Path (Part 1) -- Embedding

import math

import torch
import torch.nn as nn


class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        """
        :param d_model: dim of word_embedding 词嵌入维度
        :param vocab: size of word_dictionary 词典大小
        """
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        # Scale the embeddings by sqrt(d_model), following
        # "Attention Is All You Need" (Section 3.4)
        return self.lut(x) * math.sqrt(self.d_model)


if __name__ == '__main__':
    d_model = 512
    vocab = 1000
    x = torch.LongTensor([[100, 2, 412, 508],
                          [491, 998, 1, 221]])
    emb = Embeddings(d_model, vocab)
    embr = emb(x)
    # embr has shape (2, 4, 512): (batch, seq_len, d_model)
    print("embr", embr)

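Why multiply by sqrt(d_model)? The scaling comes from "Attention Is All You Need" (Section 3.4), which multiplies the embedding weights by sqrt(d_model). One common reading is that it enlarges the embedding vectors relative to the sinusoidal positional encodings (whose entries lie in [-1, 1]) that get added to them in the next step, so the token identity is not drowned out. Here is a quick sketch of the magnitudes under PyTorch's default N(0, 1) embedding initialization; this illustrates the untrained case only, not trained weights.

import math

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 512
lut = nn.Embedding(1000, d_model)
x = torch.LongTensor([[100, 2, 412, 508]])

raw = lut(x)
scaled = raw * math.sqrt(d_model)
# Per-token L2 norms: with N(0, 1) init, raw norms are around
# sqrt(d_model) ~ 22.6, and scaled norms around d_model ~ 512.
print(raw.norm(dim=-1))
print(scaled.norm(dim=-1))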