torch.nn.Embedding类(nn.Module)详解(还未写完)

torch.nn.Embedding类(nn.Module)

  • torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None)

    num_embeddings是输入的字典大小,如果你输入里面有10个不一样的元素,然后你的字典设置为9,会导致如下错误发生

    >>> x = torch.arange(10).reshape(5,-1)  # 输入有10个不一样的元素
    >>> embedding = torch.nn.Embedding(9, 3)  # Embedding层字典大小只有9<10导致错误
    >>> embedding(x)
    Traceback (most recent call last):
      File "D:\Anaconda\envs\first_semester_of_master\lib\site-packages\IPython\core\interactiveshell.py", line 3553, in run_code
        exec(code_obj, self.user_global_ns, self.user_ns)
      File "", line 1, in <module>
        embedding(x)
      File "D:\Anaconda\envs\first_semester_of_master\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "D:\Anaconda\envs\first_semester_of_master\lib\site-packages\torch\nn\modules\sparse.py", line 160, in forward
        self.norm_type, self.scale_grad_by_freq, self.sparse)
      File "D:\Anaconda\envs\first_semester_of_master\lib\site-packages\torch\nn\functional.py", line 2199, in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    IndexError: index out of range in self
    

    embedding_dim输出的每个词向量的维度

    >>> # 接上面
    >>> embedding = torch.nn.Embedding(10, 3)
    >>> embedding(x).shape
    torch.Size([5, 2, 3])
    

    embedding = torch.nn.Embedding(10, 3)的内部权重的结构是 ( 10 × 3 ) (10\times3) (10×3)

    即针对每个x的值如1,寻找权重中的weight[1,:]使用这个做结果第三维的值

    例子:

    >>> em=torch.rand(3,2)
    >>> em
    Out: 
    tensor([[0.8452, 0.5519],  # 0对应向量
            [0.7074, 0.3604],  # 1对应向量
            [0.8119, 0.3996]]) # 2对象向量
    >>> input=torch.tensor([0,1,2,1])
    >>> torch.nn.functional.embedding(input,em)
    Out: 
    tensor([[0.8452, 0.5519],  # 0
            [0.7074, 0.3604],  # 1
            [0.8119, 0.3996],  # 2
            [0.7074, 0.3604]])  # 1
    
  • 后面的参数用到再说

你可能感兴趣的:(pytorch,深度学习,pytorch,python)