torch.nn.Embedding参数详解-之-num_embeddings,embedding_dim

torch.nn.Embedding()

1、关于torch.nn.Embedding()前两个参数的详解。

先贴函数全貌

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None

2、详解常用的前两个参数。常用的调用形式:nn.Embedding(10, 3)

平时调用的一般形式如下(官网的例子):

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

2.1、第一个参数的来源:

其中nn.Embedding(10,3)中的10:10=maximum index + 1即input的最大值加1,上面input最大为9,所以这里的第一个参数为10.
实验一下将input的最大值改为10,结果报错。
torch.nn.Embedding参数详解-之-num_embeddings,embedding_dim_第1张图片

2.2、第二个参数的来源:

其中nn.Embedding(10,3)中的3,表示是我们指定的nn.Embedding()输出的结果中每个向量(最里面的[])包含3个元素。如下图所示:
torch.nn.Embedding参数详解-之-num_embeddings,embedding_dim_第2张图片

3、总结:

所以torch.nn.Embedding(num_embeddings, num_embeddings,省略)这两个必须要填的参数,num_embeddings是需要进行embedding的数据决定的,num_embeddings是我们自己决定的。

你可能感兴趣的:(pytorch,python,开发语言)