LSTM模型分析及对时序数据预测的具体实现(python实现)

这篇博客衔接上一篇博客: Holt-Winters模型原理分析及代码实现(python),我们在三次指数平滑的基础上,来进一步讨论下对时序数据的预测。

LSTM原理分析(参考博文:Understanding LSTM Networks)

Long Short Term 网络–LSTM,是神经网络的一种简单延伸,也是一种特殊的RNN模型。可以用来学习长期依赖的信息。LSTM 由Hochreiter & Schmidhuber (1997)提出,并在近期被Alex Graves进行了改良和推广。LSTM在语言模型,图像捕捉等领域有着极其广泛的应用。 
所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第1张图片 
LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于 单一神经网络层,这里是有四个,以一种非常特殊的方式进行交互。如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第2张图片 
那么我们可以知道LSTM实现长期记忆的必要条件如下:

增加遗忘机制。例如当一个场景结束是,模型应该重置场景的相关信息,例如位置、时间等。而一个角色死亡,模型也应该记住这一点。所以,我们希望模型学会一个独立的忘记/记忆机制,当有新的输入时,模型应该知道哪些信息应该丢掉。 
如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第3张图片 
增加保存机制。当模型看到一副新图的时候,需要学会其中是否有值得使用和保存的信息。 
如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第4张图片 
所以当有一个新的输入时,模型首先忘掉哪些用不上的长期记忆信息,然后学习新输入有什么值得使用的信息,然后存入长期记忆中。 
如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第5张图片 
把长期记忆聚焦到工作记忆中。最后,模型需要学会长期记忆的哪些部分立即能派上用场。不要一直使用完整的长期记忆,而要知道哪些部分是重点。 
如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第6张图片

python实现

环境

Python 3.6 
TensorFlow 
Numpy 
Keras 
Matplotlib

构造数据

构造一个-50~50步长为1,大小为sinx的数据序列,代码如下图所示:

x = np.arange(-50.0, 50.0, 1)
y1 = np.sin(x)
y1 = np.array(y1)
plt.plot(x, y1,'ko--')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5

如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第7张图片

构造lstm需要的数据格式

Keras LSTM层的工作方式是通过接收3维(N,W,F)的数字阵列,其中N是训练序列的数目,W是序列长度,F是每个序列的特征数目。我使用了[1,5,20,1]的网络结构,其中我们有1个输入层(由大小为50的序列组成),该输入层喂食5个神经元给LSTM层,接着该LSTM层喂食20个神经元给另一个LSTM层,然后使用一个线性激活函数来喂食一个完全连接的正常层以用于下一个时间步的预测。

x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))  
.
.
.
model = lstm.build_model([1, 5, 20, 1])
model.fit(
            x_train,
            y_train,
            batch_size=512,
            nb_epoch=epochs,
            validation_split=0.05)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

预测及结果

predicted = lstm.predict_point_by_point(model, x_test)
plot_results(predicted, y_test)
  • 1
  • 2

如下图所示: 
LSTM模型分析及对时序数据预测的具体实现(python实现)_第8张图片

你可能感兴趣的:(深度学习)