keras 的 example 文件 addition_rnn.py 解析

该代码实现了通过神经网络来计算两个三位数的相加

先生成一堆训练数据,打印一下

print(questions[:10])
print(expected[:10])

结果为:

[' 31+991', ' 46+154', '    0+2', '    9+9', '    1+7', '  827+2', '  97+09', '    0+8', '    5+3', '  5+239']
['212 ', '515 ', '2   ', '18  ', '8   ', '730 ', '169 ', '8   ', '8   ', '937 ']

编码的时候,questions是前面加空格,后面是真实的计算字符串,也就是右对齐

expected是后面加空格,也就是说expected字符串是左对齐

然后进行编码,参考下面的questions编码方式

 31+991
[[ True False False False False False False False False False False False]
 [False False False False False  True False False False False False False]
 [False False False  True False False False False False False False False]
 [False  True False False False False False False False False False False]
 [False False False False False False False False False False False  True]
 [False False False False False False False False False False False  True]
 [False False False  True False False False False False False False False]]
 46+154
[[ True False False False False False False False False False False False]
 [False False False False False False  True False False False False False]
 [False False False False False False False False  True False False False]
 [False  True False False False False False False False False False False]
 [False False False  True False False False False False False False False]
 [False False False False False False False  True False False False False]
 [False False False False False False  True False False False False False]]
    0+2
[[ True False False False False False False False False False False False]
 [ True False False False False False False False False False False False]
 [ True False False False False False False False False False False False]
 [ True False False False False False False False False False False False]
 [False False  True False False False False False False False False False]
 [False  True False False False False False False False False False False]
 [False False False False  True False False False False False False False]]

上面的一行,分别对应[空格, +, 0,1,2,3,4,5,6,7,8,9],所以字符串进行了类似的one-hot编码

expected也是一样:

212
[[False False False False  True False False False False False False False]
 [False False False  True False False False False False False False False]
 [False False False False  True False False False False False False False]
 [ True False False False False False False False False False False False]]
515
[[False False False False False False False  True False False False False]
 [False False False  True False False False False False False False False]
 [False False False False False False False  True False False False False]
 [ True False False False False False False False False False False False]]
2
[[False False False False  True False False False False False False False]
 [ True False False False False False False False False False False False]
 [ True False False False False False False False False False False False]
 [ True False False False False False False False False False False False]]

因为expected中没有加号,所以第二列永远为False

 

x_train.shape和y_train.shape分别为(45000, 7, 12) (45000, 4, 12)

神经网络模型为:

__________________________________________________________________________________________
Layer (type)                            Output Shape                        Param #
==========================================================================================
lstm_1 (LSTM)                           (None, 128)                         72192
__________________________________________________________________________________________
repeat_vector_1 (RepeatVector)          (None, 4, 128)                      0
__________________________________________________________________________________________
lstm_2 (LSTM)                           (None, 4, 128)                      131584
__________________________________________________________________________________________
time_distributed_1 (TimeDistributed)    (None, 4, 12)                       1548
==========================================================================================
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
__________________________________________________________________________________________

上面可以看到,两个LSTM的输出shape不一样,一个是(None, 128),另一个是(None, 4, 128),这是因为第一个RNN的return_sequences为False,而第一个RNN的return_sequences为True

代码解释参考官方教程:

https://keras.io/zh/examples/addition_rnn/

 

——————————————————————

总目录

keras的example文件解析

你可能感兴趣的:(TensorFlow,python)