torch.nn.Embedding()
前两个参数的详解。先贴函数全貌
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None
nn.Embedding(10, 3)
平时调用的一般形式如下(官网的例子):
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902, 0.7172],
[-0.6431, 0.0748, 0.6969],
[ 1.4970, 1.3448, -0.9685],
[-0.3677, -2.7265, -0.1685]],
[[ 1.4970, 1.3448, -0.9685],
[ 0.4362, -0.4004, 0.9400],
[-0.6431, 0.0748, 0.6969],
[ 0.9124, -2.3616, 1.1151]]])
其中nn.Embedding(10,3)中的
10:10=maximum index + 1
即input的最大值加1,上面input最大为9,所以这里的第一个参数为10.
实验一下将input的最大值改为10,结果报错。
其中nn.Embedding(10,3)中的
3,表示是我们指定的nn.Embedding()
输出的结果中每个向量(最里面的[]
)包含3个元素。如下图所示:
所以torch.nn.Embedding(num_embeddings, num_embeddings,省略)
这两个必须要填的参数,num_embeddings
是需要进行embedding的数据决定的,num_embeddings
是我们自己决定的。