pytorch embedding 理解

可见知乎

之前由于懒且不重视没思考embedding,故近期实战对其有些“误解”。害,可见这都是之前欠下的债啊,得补上!小白记录贴,仅供参考~

从pytorch源码里简单找了找,并没有找到对embedding有直观解释的代码,故转向tensorflow。

不管什么框架,原理得是一样的吧~对embedding追根溯源,发现主要包括两部分:

  1. 对input[batch_size, seq_len]进行one-hot编码[batch_size, vocab_size];
  2. 将one-hot编码后的矩阵和weight[vacab_size, embed_dim矩阵相乘;

 

复现代码如下:

利用pytorch给的接口可以得到embedding之后的值如下:

import torch.nn.functional as F
import torch

input = torch.tensor([[1,2,4,5]])
weights = torch.rand(10, 3)
F.embedding(input, weights)
'''
tensor([[[0.2776, 0.0587, 0.9897],
         [0.9066, 0.3682, 0.0840],
         [0.0370, 0.3854, 0.0091],
         [0.5261, 0.5255, 0.1317]]])
'''

按照理解改写如下,得到相同的结果:

import numpy as np
np.matmul(tf.one_hot(input,depth=10),weights)
'''
tensor([[[0.9566, 0.8623, 0.8421],
         [0.7956, 0.9499, 0.0336],
         [0.4343, 0.6607, 0.8412],
         [0.2082, 0.7314, 0.6296]]])
'''

2014年的论文text-CNN 有三种不同的embedding机制:rand/static/non static/,其中static利用训练好的word2vec向量,而non-static应该就是将embedding作为网络的一部分进行训练,这里的微调其实并不是input有所改变,而是embedding层之后的x_emb有改变!归根到底是fine-tune网络~

 

参考文献:

申小明77:tensorflow中的Embedding操作详解​zhuanlan.zhihu.com

【python】np.dot()、np.multiply()、np.matmul()方法以及*和@运算符的用法总结_敲代码的quant的博客-CSDN博客​blog.csdn.net

你可能感兴趣的:(深度学习,预处理,pytorch,自然语言处理,神经网络)