浅析tf.keras.layers.Embedding

浅析tf.keras.layers.Embedding

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pydot
input_array = np.random.randint(1000, size=(4, 10))
print("batch语句数量?: 4,语句长度:10")
print("相当于有4句长度为10的话")
input_array
batch语句数量?: 4,语句长度:10
相当于有4句长度为10的话

array([[426, 673, 370, 70, 881, 486, 141, 629, 355, 839],
[604, 968, 109, 31, 932, 944, 612, 527, 516, 887],
[684, 235, 911, 132, 574, 916, 744, 542, 679, 929],
[110, 751, 29, 721, 152, 929, 435, 483, 504, 194]])

model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
print("1000是文本中不同词汇的词汇量,64是需要embedding的文本量——也就是用64纬的向量来表示文本")
print("10表示输入的文本长度是10")
print("")
model.summary()
1000是文本中不同词汇的词汇量,64是需要embedding的文本量——也就是用64纬的向量来表示文本
10表示输入的文本长度是10

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 10, 64)            64000     
=================================================================
Total params: 64,000
Trainable params: 64,000
Non-trainable params: 0
_________________________________________________________________
model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
print(output_array.shape)
print("4句话,每句话长度10个单词,每个单词对应embedding的64纬的向量")
(4, 10, 64)
4句话,每句话长度10个单词,每个单词对应embedding的64纬的向量
output_array
array([[[-0.02118476, -0.03898944,  0.01221662, ...,  0.03329719,
          0.02157934, -0.04377121],
        [-0.03383125, -0.001796  ,  0.01168343, ...,  0.02926688,
          0.04473437, -0.04309578],
        [ 0.03530608,  0.02963002, -0.01404522, ..., -0.01992302,
          0.04386846,  0.03756655],
        ...,
        [-0.03670976,  0.03208424,  0.01248705, ..., -0.00762266,
          0.04845474, -0.04646262],
        [ 0.01602076, -0.02721908,  0.04094582, ...,  0.02935955,
          0.0319365 , -0.03111139],
        [ 0.01702071, -0.03066396,  0.02680251, ...,  0.03202749,
          0.00530231, -0.02432451]],

       [[-0.01825579, -0.00619879, -0.00451747, ..., -0.00568254,
          0.00859107,  0.02297581],
        [-0.00014025,  0.00345597, -0.02279505, ...,  0.03197813,
          0.0268837 , -0.02906168],
        [ 0.01594374,  0.04538775, -0.01887939, ..., -0.01630086,
         -0.01572884, -0.02938943],
        ...,
        [ 0.01609715, -0.03346112,  0.04200361, ...,  0.04576402,
          0.02138535,  0.0185445 ],
        [ 0.01425153,  0.03494593, -0.03951036, ...,  0.00054593,
          0.03156869, -0.04687175],
        [ 0.00589864, -0.03460861,  0.01454576, ..., -0.03167019,
          0.01469716, -0.03094819]],

       [[-0.01308862, -0.04916516,  0.02899685, ...,  0.04728491,
         -0.00486624,  0.01209624],
        [ 0.01172228,  0.0369903 ,  0.04188201, ...,  0.00603905,
         -0.04522462,  0.04354359],
        [ 0.00095894, -0.02864948,  0.02627457, ..., -0.01084584,
          0.04538004, -0.03449055],
        ...,
        [ 0.01551476, -0.01011632, -0.00195837, ...,  0.0099923 ,
          0.01555574, -0.00624264],
        [ 0.04337336, -0.01898077,  0.04472036, ..., -0.02179148,
          0.04134902, -0.04898594],
        [ 0.01520981,  0.03971297,  0.01663113, ..., -0.00667095,
         -0.04221702, -0.01182438]],

       [[ 0.00828315, -0.00025446, -0.03170226, ...,  0.03994179,
         -0.00012902, -0.02713431],
        [ 0.02764615,  0.02142035,  0.01596281, ..., -0.03496212,
          0.02562487,  0.03629917],
        [-0.01850617, -0.02221829,  0.00282773, ...,  0.02816875,
         -0.04975935,  0.03505694],
        ...,
        [ 0.0262561 , -0.02659642,  0.00986709, ...,  0.01436288,
          0.03989737, -0.03775549],
        [-0.0483313 ,  0.04767383, -0.00810294, ..., -0.01985493,
         -0.01641266, -0.02623077],
        [ 0.0204062 , -0.02863589,  0.00803487, ..., -0.00382262,
          0.04217901, -0.02859993]]], dtype=float32)

你可能感兴趣的:(TF2.0)