torch.nn.Embedding的函数

保存了固定字典和大小的简单查找表。模块常用来保存词嵌入和用下标检索它们。模块的输入是一个下标的列表,输出是对应的词嵌入

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

一个矩阵类,里面初始化了一个随机矩阵,矩阵的长是字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。

word2idx={'hello':0, 'world':1}
import torch.nn as nn
embed = nn.Embedding(2, 5)
import torch
helloidx = torch.LongTensor([word2idx['hello']])
helloidx = torch.autograd.Variable(helloidx)
helloembed = embed(helloidx)

helloembed
Out[9]:
tensor([[ 2.0503, 0.1744, -0.9441, 0.9574, -0.0701]],
grad_fn=)
embed
Out[10]: Embedding(2, 5)
type(embed)
Out[11]: torch.nn.modules.sparse.Embedding
embed.weight.data
Out[14]:
tensor([[ 2.0503, 0.1744, -0.9441, 0.9574, -0.0701],
[ 1.2634, 0.7174, 0.0343, -0.1236, 0.6455]])

你可能感兴趣的:(pytorch)