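This code uses a neural network to add two numbers of up to three digits each.
First, generate a batch of training data and print the first ten samples: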
print(questions[:10])
print(expected[:10])
Output:
[' 31+991', ' 46+154', '    0+2', '    9+9', '    1+7', '  827+2', '  97+09', '    0+8', '    5+3', '  5+239']
['212 ', '515 ', '2   ', '18  ', '8   ', '730 ', '169 ', '8   ', '8   ', '937 ']
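In questions the padding spaces come first and the expression follows, so each question string is right-aligned at a fixed width of 7 characters. The expression itself is stored reversed: ' 31+991' stands for 199+13, whose sum is 212.
In expected the padding spaces come last, so each answer string is left-aligned at a fixed width of 4 characters.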
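The padding and reversal described here can be sketched in a few lines. This is only a minimal illustration, not the original generator: the names DIGITS, MAXLEN, format_pair and make_pair are assumptions, and details such as how operand lengths are sampled or how duplicates are filtered are left out.

```python
import random

DIGITS = 3                    # assumed: each operand has at most three digits
MAXLEN = DIGITS + 1 + DIGITS  # "ddd+ddd" needs 7 characters


def format_pair(a, b):
    """Build the fixed-width (question, expected) strings for a + b."""
    query = "{}+{}".format(a, b).ljust(MAXLEN)  # pad with spaces on the right
    question = query[::-1]                      # reverse: padding moves to the front
    expected = str(a + b).ljust(DIGITS + 1)     # left-aligned sum, 4 characters
    return question, expected


def make_pair(rng):
    """Draw two random operands and format them (simplified sampling)."""
    a = rng.randint(0, 10 ** DIGITS - 1)
    b = rng.randint(0, 10 ** DIGITS - 1)
    return format_pair(a, b)


print(format_pair(199, 13))  # (' 31+991', '212 ')

rng = random.Random(0)
for _ in range(3):
    print(make_pair(rng))
```

Reversing the padded query is what makes the question look right-aligned even though the padding was originally appended on the right.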
Both are then encoded. The questions are encoded as follows:
31+991
[[ True False False False False False False False False False False False]
[False False False False False True False False False False False False]
[False False False True False False False False False False False False]
[False True False False False False False False False False False False]
[False False False False False False False False False False False True]
[False False False False False False False False False False False True]
[False False False True False False False False False False False False]]
46+154
[[ True False False False False False False False False False False False]
[False False False False False False True False False False False False]
[False False False False False False False False True False False False]
[False True False False False False False False False False False False]
[False False False True False False False False False False False False]
[False False False False False False False True False False False False]
[False False False False False False True False False False False False]]
0+2
[[ True False False False False False False False False False False False]
[ True False False False False False False False False False False False]
[ True False False False False False False False False False False False]
[ True False False False False False False False False False False False]
[False False True False False False False False False False False False]
[False True False False False False False False False False False False]
[False False False False True False False False False False False False]]
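Each row above is one character, and the columns correspond to [space, +, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], so every string is one-hot encoded character by character.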
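A minimal sketch of this character-level one-hot encoding, assuming the 12-symbol vocabulary listed above in sorted order. The names CHARS, CHAR_INDEX, encode and decode are illustrative, and plain Python lists are used here instead of the boolean arrays shown in the printed output, to keep the sketch dependency-free.

```python
# Sorting puts space first, then '+', then the digits 0..9.
CHARS = sorted(set("0123456789+ "))
CHAR_INDEX = {c: i for i, c in enumerate(CHARS)}


def encode(text):
    """Return one boolean row per character, True only at that character's column."""
    rows = []
    for ch in text:
        row = [False] * len(CHARS)
        row[CHAR_INDEX[ch]] = True
        rows.append(row)
    return rows


def decode(rows):
    """Invert encode() by looking up the True column of every row."""
    return "".join(CHARS[row.index(True)] for row in rows)


encoded = encode(" 31+991")
print(len(encoded), len(encoded[0]))  # 7 12
print(decode(encoded) == " 31+991")   # True
```

Stacking one such 7 x 12 block per question gives a three-dimensional training array with one slice per sample.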
expected is encoded the same way:
212
[[False False False False True False False False False False False False]
[False False False True False False False False False False False False]
[False False False False True False False False False False False False]
[ True False False False False False False False False False False False]]
515
[[False False False False False False False True False False False False]
[False False False True False False False False False False False False]
[False False False False False False False True False False False False]
[ True False False False False False False False False False False False]]
2
[[False False False False True False False False False False False False]
[ True False False False False False False False False False False False]
[ True False False False False False False False False False False False]
[ True False False False False False False False False False False False]]
Because expected never contains a plus sign, the second column is always False.
x_train.shape and y_train.shape are (45000, 7, 12) and (45000, 4, 12), respectively.
The neural network model is:
_________________________________________________________________
Layer (type)                          Output Shape      Param #
=================================================================
lstm_1 (LSTM)                         (None, 128)       72192
_________________________________________________________________
repeat_vector_1 (RepeatVector)        (None, 4, 128)    0
_________________________________________________________________
lstm_2 (LSTM)                         (None, 4, 128)    131584
_________________________________________________________________
time_distributed_1 (TimeDistributed)  (None, 4, 12)     1548
=================================================================
Total params: 205,324
Trainable params: 205,324
Non-trainable params: 0
_________________________________________________________________
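The parameter counts in the summary can be checked by hand. The sketch below uses the standard formulas for an LSTM layer (four gates, each with input weights, recurrent weights and a bias) and for a dense layer with bias; the helper names lstm_params and dense_params are illustrative, not part of any library.

```python
def lstm_params(input_dim, units):
    """4 gates, each with (input_dim + units) weights per unit plus one bias."""
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    """Weight matrix plus one bias per output unit."""
    return input_dim * units + units


lstm_1 = lstm_params(12, 128)   # input is a 12-way one-hot vector
lstm_2 = lstm_params(128, 128)  # input is the 128-dim repeated encoding
dense = dense_params(128, 12)   # shared across the 4 output timesteps

print(lstm_1, lstm_2, dense, lstm_1 + lstm_2 + dense)
# prints: 72192 131584 1548 205324
```

The dense layer is counted only once because TimeDistributed applies the same weights at every timestep.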
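As shown above, the two LSTM layers have different output shapes: one is (None, 128) and the other is (None, 4, 128). This is because the first LSTM has return_sequences set to False, while the second has return_sequences set to True.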
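To make the effect of return_sequences concrete, the shapes can be traced through the four layers with a few small helpers. These functions are hypothetical stand-ins that only mimic the documented shape behavior of the corresponding Keras layers; they do not build or run a model.

```python
def lstm_shape(input_shape, units, return_sequences=False):
    """With return_sequences=False only the last step's output is kept."""
    batch, timesteps, _ = input_shape
    return (batch, timesteps, units) if return_sequences else (batch, units)


def repeat_vector_shape(input_shape, n):
    """Repeat a single vector n times along a new time axis."""
    batch, features = input_shape
    return (batch, n, features)


def time_distributed_dense_shape(input_shape, units):
    """Apply the same dense layer independently at every timestep."""
    batch, timesteps, _ = input_shape
    return (batch, timesteps, units)


shape = (None, 7, 12)                                     # one-hot questions
shape = lstm_shape(shape, 128)                            # (None, 128)
print(shape)
shape = repeat_vector_shape(shape, 4)                     # (None, 4, 128)
print(shape)
shape = lstm_shape(shape, 128, return_sequences=True)     # (None, 4, 128)
print(shape)
shape = time_distributed_dense_shape(shape, 12)           # (None, 4, 12)
print(shape)
```

The first LSTM compresses the whole 7-character question into one 128-dimensional vector, and the second LSTM must emit one vector per output character, which is why it needs the full sequence.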
For an explanation of the code, see the official tutorial:
https://keras.io/zh/examples/addition_rnn/