import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pydot
input_array = np.random.randint(1000, size=(4, 10))
print("batch语句数量?: 4,语句长度:10")
print("相当于有4句长度为10的话")
input_array
batch语句数量?: 4,语句长度:10
相当于有4句长度为10的话
array([[426, 673, 370, 70, 881, 486, 141, 629, 355, 839],
[604, 968, 109, 31, 932, 944, 612, 527, 516, 887],
[684, 235, 911, 132, 574, 916, 744, 542, 679, 929],
[110, 751, 29, 721, 152, 929, 435, 483, 504, 194]])
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
print("1000是文本中不同词汇的词汇量,64是需要embedding的文本量——也就是用64纬的向量来表示文本")
print("10表示输入的文本长度是10")
print("")
model.summary()
1000是文本中不同词汇的词汇量,64是需要embedding的文本量——也就是用64纬的向量来表示文本
10表示输入的文本长度是10
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 10, 64) 64000
=================================================================
Total params: 64,000
Trainable params: 64,000
Non-trainable params: 0
_________________________________________________________________
model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
print(output_array.shape)
print("4句话,每句话长度10个单词,每个单词对应embedding的64纬的向量")
(4, 10, 64)
4句话,每句话长度10个单词,每个单词对应embedding的64纬的向量
output_array
array([[[-0.02118476, -0.03898944, 0.01221662, ..., 0.03329719,
0.02157934, -0.04377121],
[-0.03383125, -0.001796 , 0.01168343, ..., 0.02926688,
0.04473437, -0.04309578],
[ 0.03530608, 0.02963002, -0.01404522, ..., -0.01992302,
0.04386846, 0.03756655],
...,
[-0.03670976, 0.03208424, 0.01248705, ..., -0.00762266,
0.04845474, -0.04646262],
[ 0.01602076, -0.02721908, 0.04094582, ..., 0.02935955,
0.0319365 , -0.03111139],
[ 0.01702071, -0.03066396, 0.02680251, ..., 0.03202749,
0.00530231, -0.02432451]],
[[-0.01825579, -0.00619879, -0.00451747, ..., -0.00568254,
0.00859107, 0.02297581],
[-0.00014025, 0.00345597, -0.02279505, ..., 0.03197813,
0.0268837 , -0.02906168],
[ 0.01594374, 0.04538775, -0.01887939, ..., -0.01630086,
-0.01572884, -0.02938943],
...,
[ 0.01609715, -0.03346112, 0.04200361, ..., 0.04576402,
0.02138535, 0.0185445 ],
[ 0.01425153, 0.03494593, -0.03951036, ..., 0.00054593,
0.03156869, -0.04687175],
[ 0.00589864, -0.03460861, 0.01454576, ..., -0.03167019,
0.01469716, -0.03094819]],
[[-0.01308862, -0.04916516, 0.02899685, ..., 0.04728491,
-0.00486624, 0.01209624],
[ 0.01172228, 0.0369903 , 0.04188201, ..., 0.00603905,
-0.04522462, 0.04354359],
[ 0.00095894, -0.02864948, 0.02627457, ..., -0.01084584,
0.04538004, -0.03449055],
...,
[ 0.01551476, -0.01011632, -0.00195837, ..., 0.0099923 ,
0.01555574, -0.00624264],
[ 0.04337336, -0.01898077, 0.04472036, ..., -0.02179148,
0.04134902, -0.04898594],
[ 0.01520981, 0.03971297, 0.01663113, ..., -0.00667095,
-0.04221702, -0.01182438]],
[[ 0.00828315, -0.00025446, -0.03170226, ..., 0.03994179,
-0.00012902, -0.02713431],
[ 0.02764615, 0.02142035, 0.01596281, ..., -0.03496212,
0.02562487, 0.03629917],
[-0.01850617, -0.02221829, 0.00282773, ..., 0.02816875,
-0.04975935, 0.03505694],
...,
[ 0.0262561 , -0.02659642, 0.00986709, ..., 0.01436288,
0.03989737, -0.03775549],
[-0.0483313 , 0.04767383, -0.00810294, ..., -0.01985493,
-0.01641266, -0.02623077],
[ 0.0204062 , -0.02863589, 0.00803487, ..., -0.00382262,
0.04217901, -0.02859993]]], dtype=float32)