如何进行word embedding(tensorflow实现)

1. 什么是word embedding

  • 通过一定的方式将词汇映射到指定维度(一般是更高维度)的空间
  • 广义的word embedding包括所有密集词汇向量的表示方法,如之前学习的word2vec,即可认为是word embedding的一种
  • 狭义的word embedding是指在神经网络中加入embedding层,对整个网络进行训练时产生的embedding矩阵(embedding层的参数),这个embedding矩阵就是训练过程中所有输入词汇的向量表示组成的矩阵。

(1)从word到num

我们的自然语言,不管是中文还是英文都不能直接在机器中表达,此时就要将自然语言映射为数字。要映射成数字就要有字典,所以一般会先构建词典,举例如下:

word_dict = {"我":0, "你":1, "他":2, "她":3, "是":4, "好":5, 
             "坏":6, "人":7, "天":8, "第":9, "气":10, "今":11,
             "怎":12, "么":13, "样":14, "啊":15}

我们假设词典的大小为15,即voc_size=15,我们的sentences为"今天天气怎么样"和"他人怎么样",这样的话通过查表就可以得到如下的表示:

word2index = [[11, 8, 8, 10, 12, 13, 14],
              [2, 7, 12, 13, 14]] 

(2)pad

可以看到因为各个句子的长度不一样,所以生成的矩阵不整齐,这样也不利于进行矩阵计算,所以需要进行pad,将各个向量转为长度相同的向量。最简单的方法就是填0:

padded = [[11, 8, 8, 10, 12, 13, 14, 0, 0, 0],
          [2, 7, 12, 13, 14, 0, 0, 0, 0, 0]] 

如上所示就是将input_length设为10,如果原本长度小于10的补0,大于10的截断。

上面的matrix看起来并不是one_hot的形式,但实际上上式跟下面的表示是等价的:

[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]

(3)embedding

由于one_hot编码的稀疏性,且这种编码无法描述两个元素之间的相关性,所以可以用embedding编码。比如在上面的one_hot编码中我们是用15维的特征来描述一个字的,上面的“我”和“你”两个向量点乘的话结果为0,完全是没有关系的,但如果用另一种方式编码的话就可以有关系了:

   人称代词    名词   动词  形容词
我    0.9     0.5   0.2    0.20.8     0.6   0.3.   0.1

这样的话就可以将10维的特征描述转为4维的特征描述,且效果看起来会更好一些。

当然在进行embedding的时候不是人为设定的特征,而是人为设定好想要的特征维数之后通过语料训练得到的。

2. 实现

用的是tensorflow2.0

from tensorflow.keras.preprocessing.text import one_hot

sentences=['the glass of milk',
        'the glass of juice',
        'the cup of tea',
        'I am a good boy']

# 设置字典大小
voc_size=10000
# 这里进行one_hot映射用的是tensorflow内部给提供的映射词典
onehot_repr=[one_hot(words, voc_size) for words in sentences]
print(onehot_repr)

from tensorflow.keras.layers import Embedding
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential

import numpy as np

sent_length=8
# padding='pre'是在前面填充,padding='post'是在后面填充
embedded_docs=pad_sequences(onehot_repr, padding='pre', maxlen=sent_length) 
print(embedded_docs)

dim = 10
model = Sequential()
# 这里的dim是把每一个word映射到一个10维的向量,所以映射只有,原本(4, 8)的矩阵变成了一个(4, 8, 10)的矩阵
model.add(Embedding(voc_size, dim, input_length=sent_length))
model.compile('adam', 'mse')

print(model.summary())

#因为没有文本用来训练,所以这里的vector是随机赋的初值
vector = model.predict(embedded_docs)
print(model.predict(embedded_docs))
print(vector.shape)

结果:

>>> print(onehot_repr)
[[8174, 3076, 416, 6851], [8174, 3076, 416, 6687], [8174, 9660, 416, 8721], [5223, 4222, 5952, 6180, 5440]]

>>> print(embedded_docs)
[[   0    0    0    0 8174 3076  416 6851]
 [   0    0    0    0 8174 3076  416 6687]
 [   0    0    0    0 8174 9660  416 8721]
 [   0    0    0 5223 4222 5952 6180 5440]]

>>> print(model.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        (None, 8, 10)             100000
=================================================================
Total params: 100,000
Trainable params: 100,000
Non-trainable params: 0
_________________________________________________________________
None

>>> print(model.predict(embedded_docs))
[[[ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [-0.02706006  0.01063529 -0.01942388  0.02701591 -0.04124977
   -0.00983888  0.01273515  0.03012211  0.04841721 -0.01894962]
  [ 0.00510359  0.01853384 -0.02409974 -0.02285388  0.04018563
   -0.04754727  0.02264073 -0.01251531 -0.04369598  0.03063634]
  [-0.03827395  0.0083343  -0.03649645  0.00391301 -0.0283778
    0.04224857  0.03885354 -0.01442292  0.01358733 -0.03044585]
  [-0.02544751 -0.02753698  0.00250997 -0.01593918  0.04284723
    0.03717153  0.01787357  0.01125566 -0.0267596  -0.0248112 ]]

 [[ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [-0.02706006  0.01063529 -0.01942388  0.02701591 -0.04124977
   -0.00983888  0.01273515  0.03012211  0.04841721 -0.01894962]
  [ 0.00510359  0.01853384 -0.02409974 -0.02285388  0.04018563
   -0.04754727  0.02264073 -0.01251531 -0.04369598  0.03063634]
  [-0.03827395  0.0083343  -0.03649645  0.00391301 -0.0283778
    0.04224857  0.03885354 -0.01442292  0.01358733 -0.03044585]
  [ 0.03599827 -0.00697263  0.01096133  0.01282989  0.04026625
   -0.0409615  -0.03822895  0.03571489 -0.03869583  0.0247351 ]]

 [[ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [-0.02706006  0.01063529 -0.01942388  0.02701591 -0.04124977
   -0.00983888  0.01273515  0.03012211  0.04841721 -0.01894962]
  [-0.04196328 -0.0178645   0.01629119  0.00710867  0.03742753
    0.04766042 -0.01144195 -0.00392986  0.04960826  0.01370332]
  [-0.03827395  0.0083343  -0.03649645  0.00391301 -0.0283778
    0.04224857  0.03885354 -0.01442292  0.01358733 -0.03044585]
  [ 0.03060566 -0.01925355 -0.01740856  0.00497576 -0.04157882
    0.01061495  0.04219753 -0.02456384  0.03463561 -0.01594185]]

 [[ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [ 0.04522029 -0.01226036  0.00501914 -0.04841211 -0.03839918
   -0.01737207  0.01646307 -0.02959521  0.03347592  0.0029294 ]
  [-0.01688796  0.02185148  0.01407048  0.01172693 -0.04144372
   -0.02081727  0.02715001  0.01198126  0.00415362  0.02064079]
  [ 0.01959873  0.03910967 -0.03127551 -0.04483137 -0.01185248
    0.03648222  0.04708296 -0.00957827 -0.002679    0.03122015]
  [ 0.00080504  0.00700544  0.02628921  0.0229356   0.04947283
    0.01667294  0.03602554 -0.01248958 -0.00070317 -0.03361555]
  [-0.0488179  -0.02457787 -0.03306667 -0.03750541 -0.03436396
   -0.04636976 -0.03443474  0.00712519 -0.02974316  0.03063191]
  [ 0.0434726   0.04021135 -0.03558815  0.04452255  0.04240603
   -0.011404    0.00316377 -0.01917359  0.03822576 -0.01635139]]]

>>> print(vector.shape)
(4, 8, 10)

你可能感兴趣的:(NLP,tensorflow,自然语言处理,机器学习)