LSTM 详解及 LSTM 解决时间序列预测问题(附代码)

本文章主要涉及以下工作:
   (1)详细介绍了 LSTM 的基本原理及计算过程。
   (2)利用深度学习框架 pytorch 编程,基于多层 LSTM 解决时间序列预测问题。
   (3)提供了项目的 Python 代码以及相应的使用文档。
如果文章有用,欢迎各位→点赞 + 收藏 + 留言
项目代码:Sales-Predict-With-LSTM
如果项目代码有用,请给Github项目star一下;谢谢啦

目录

    • 1. LSTM 概念
      • (1) RNN的介绍
      • (2) RNN无法处理长期依赖
      • (3) RNN与LSTM的对比
      • (4) LSTM的核心思想
      • (5) LSTM工作流程
        • a. 忘记门
        • b. 输入门
        • c. 细胞状态
        • d. 输出门
    • 2. 基于LSTM解决时间序列预测问题
      • (1) 数据集描述
      • (2) 配置文件
        • a. 数据集参数
        • b. 网络参数
        • c. 训练参数
        • d. 训练模式
        • e. 路径参数
      • (3) 运行结果展示
    • 3. 参考资料

1. LSTM 概念

(1) RNN的介绍

循环神经网络(Recurrent Neural Network,RNN)是一种人工神经网络,主要用于处理序列数据,如自然语言、音频和视频等。

与其他神经网络不同的是,RNN具有内部循环结构,可以保留先前的信息,并在后续的计算中使用它们。如下图所示,在这个循环的结构中,每个神经网络的模块 A,读取某个输入 x t x_{t} xt,并输出一个值 h t h_{t} ht,然后不断循环。循环可以使得信息可以从当前步传递到下一步。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第1张图片
RNN的每个单元都有一个状态,这个状态将被传递到下一个单元中。在每个时间步,RNN都会接收当前输入和前一时刻的状态作为输入,并输出当前时刻的状态和相应的输出。这个输出可以被用作下一个时间步的输入,形成一个循环。

由于RNN具有内部循环结构,因此它可以处理任意长度的序列数据,并且可以用于各种任务,包括语言建模、机器翻译、语音识别等。此外,RNN还可以通过反向传播算法进行训练,以优化网络的权重和偏差,从而提高模型的准确性和性能。

(2) RNN无法处理长期依赖

RNN无法处理长期依赖的问题,是由于在反向传播时,梯度会在每个时间步骤中指数级地衰减,这使得远距离的信息无法有效传递。

这个问题可以通过以下方式来理解:在RNN的训练过程中,梯度是从后向前传递的,每个时间步骤的梯度都会影响前面的时间步骤。如果一些时间步骤之间存在长期的依赖关系,那么梯度在传递时会多次相乘,导致指数级的衰减。这就使得前面的时间步骤无法有效地传递信息,从而影响了模型的准确性。

举个例子,考虑一个RNN模型用于语言建模。假设RNN预测“云在天空中”中的最后一个单词,我们不需要任何进一步的上下文——很明显下一个单词将是天空。在这种情况下,相关信息和需要信息的地方之间的差距很小,RNN可以学习使用过去的信息。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第2张图片

但假设我们要用RNN生成一段长序列的文本,比如一篇新闻文章。在这种情况下,RNN需要记住文章中的前一部分内容,并在后续的计算中使用它们。然而,如果文章中存在长期依赖关系,例如在文章的开头提到了一个事件,而这个事件在文章的结尾才有进一步的描述,那么RNN很难记住这个事件,并在后续的计算中使用它们。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第3张图片

为了解决这个问题,研究人员提出了一些改进的RNN模型,如LSTM和GRU。这些模型通过引入门控机制,可以控制信息的流动和遗忘,从而有效地解决了梯度消失的问题。这些门控机制可以控制信息的流动和遗忘,使得RNN可以记住长期的依赖关系,从而提高了模型的准确性。

总的来说,RNN无法处理长期依赖的问题是由于梯度消失的现象所导致的,这会使得远距离的信息无法有效传递。为了解决这个问题,改进的RNN模型引入了门控机制,可以控制信息的流动和遗忘,从而使得RNN可以处理长期依赖关系。

(3) RNN与LSTM的对比

长短时记忆网络(Long Short-Term Memory,LSTM)是一种改进的循环神经网络(RNN)模型,通过引入门控机制来解决RNN无法处理长期依赖关系的问题。LSTM模型由Hochreiter和Schmidhuber于1997年提出,目前已被广泛应用于自然语言处理、语音识别、图像处理等领域。

所有RNN都具有一种重复神经网络模块的链式的形式。在标准的RNN中,这个重复的模块只有一个非常简单的结构,例如一个tanh层。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第4张图片

激活函数 Tanh 作用在于帮助调节流经网络的值,使得数值始终限制在 -1 和 1 之间。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第5张图片
LSTM同样是这样的结构,但是重复的模块拥有一个不同的结构。具体来说,RNN是重复单一的神经网络层,LSTM中的重复模块则包含四个交互的层,三个Sigmoid 和一个tanh层,并以一种非常特殊的方式进行交互。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第6张图片

上图中,σ表示的Sigmoid 激活函数与 tanh 函数类似,不同之处在于 sigmoid 是把值压缩到 0 ~ 1 之间而不是 -1 ~ 1 之间。这样的设置有助于更新或忘记信息:

  • 因为任何数乘以 0 都得 0,这部分信息就会剔除掉;
  • 同样的,任何数乘以 1 都得到它本身,这部分信息就会完美地保存下来

相当于要么是1则记住,要么是0则忘掉,所以还是这个原则:因记忆能力有限,记住重要的,忘记无关紧要的。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第7张图片

对于图中使用的各种元素的图标中,每一条黑线传输着一整个向量,从一个节点的输出到其他节点的输入。粉色的圈代表pointwise的操作,诸如向量的和,而黄色的矩阵就是学习到的神经网络层。合在一起的线表示向量的连接,分开的线表示内容被复制,然后分发到不同的位置。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第8张图片

(4) LSTM的核心思想

LSTM的核心思想是引入门控机制来解决RNN无法处理长期依赖关系的问题。

门是一种选择性地让信息通过的方法。它们由sigmoid神经网络层和点乘运算组成。sigmoid层输出0到1之间的数字,描述每个组件应该通过的数量。值0表示“不让任何东西通过”,而值1表示“让所有东西通过”。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第9张图片

在传统的RNN中,信息会在时间步之间传递,但是往往会出现梯度消失或爆炸的问题,导致RNN无法处理长期依赖关系。例如,在生成文字序列时,上一个字符的信息可能需要在多个时间步中传递才能影响到后续的字符,而RNN可能会忘记这些信息。

为了解决这个问题,LSTM引入了三个门控单元:输入门、遗忘门和输出门。这些门控单元可以控制信息的流动和遗忘,从而有效地解决了梯度消失的问题。具体来说:

  • 输入门控制新输入的信息的流动。在每个时间步,LSTM模型会接收一个新的输入,并根据当前的状态和前一个时间步的输出计算一个新的候选状态。这个候选状态表示新的信息,需要通过输入门来控制它的流动。
  • 遗忘门控制旧状态的遗忘。LSTM模型需要记住之前的状态,而遗忘门可以控制哪些元素需要被遗忘,哪些元素需要保留。
  • 记忆单元可以存储和更新状态信息。在每个时间步,LSTM模型需要根据输入门、遗忘门和候选状态来更新状态信息。具体来说,记忆单元可以根据遗忘门和输入门的控制,以一定的比例遗忘旧状态和融合新状态,得到新的状态。
  • 输出门控制输出的信息。LSTM模型需要将当前的状态信息转换为一个输出,并根据输出门来控制输出的信息。

(5) LSTM工作流程

a. 忘记门

当新的输入进入LSTM网络时,忘记门会决定哪些信息应该被遗忘或保留。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第10张图片

忘记门通过一个sigmoid函数来决定哪些信息需要被遗忘。在每个时间步 t t t,忘记门的计算公式如下:

f t = σ ( W f [ x t , h t − 1 ] + b f ) f_t = \sigma(W_f[x_t, h_{t-1}] + b_f) ft=σ(Wf[xt,ht1]+bf)

其中, W f W_f Wf是忘记门的权重矩阵, b f b_f bf是偏置项, x t x_t xt是当前时间步的输入, h t − 1 h_{t-1} ht1是上一时间步的隐藏状态。 σ \sigma σ是sigmoid函数,将输入值映射到0到1之间的概率值。

忘记门的输出 f t f_t ft是一个0到1之间的值,表示应该遗忘多少过去的信息。当 f t f_t ft接近1时,过去的信息会被完全保留;当 f t f_t ft接近0时,过去的信息会被完全遗忘。

忘记门的具体计算过程如下动图。
LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第11张图片

b. 输入门

当新的输入进入LSTM网络时,输入门会决定哪些信息应该被保留并更新细胞状态。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第12张图片

输入门通过一个sigmoid函数来决定哪些信息需要被保留。在每个时间步 t t t,输入门的计算公式如下:

i t = σ ( W i [ x t , h t − 1 ] + b i ) i_t = \sigma(W_i[x_t, h_{t-1}] + b_i) it=σ(Wi[xt,ht1]+bi)

其中, W i W_i Wi是输入门的权重矩阵, b i b_i bi是偏置项, x t x_t xt是当前时间步的输入, h t − 1 h_{t-1} ht1是上一时间步的隐藏状态。 σ \sigma σ是sigmoid函数,将输入值映射到0到1之间的概率值。

输入门的输出 i t i_t it是一个0到1之间的值,表示哪些新的输入应该被保留。当 i t i_t it接近1时,所有新的输入都会被完全保留;当 i t i_t it接近0时,所有新的输入都会被完全忽略。

接下来,LSTM会计算候选细胞状态 c t ~ \tilde{c_t} ct~,它表示当前时间步的新输入可以对细胞状态产生多少影响。候选细胞状态的计算公式如下:

c t ~ = tanh ⁡ ( W c [ x t , h t − 1 ] + b c ) \tilde{c_t} = \tanh(W_c[x_t, h_{t-1}] + b_c) ct~=tanh(Wc[xt,ht1]+bc)

其中, W c W_c Wc是候选细胞状态的权重矩阵, b c b_c bc是偏置项, x t x_t xt是当前时间步的输入, h t − 1 h_{t-1} ht1是上一时间步的隐藏状态。 tanh ⁡ \tanh tanh是双曲正切函数,将输入值映射到-1到1之间的值。

输入门的具体计算过程如下动图。
LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第13张图片
输入门的作用是控制新的输入在当前时间步的权重。通过输入门,LSTM能够更好地处理长序列数据,避免梯度消失和梯度爆炸的问题,从而提高模型的效果和稳定性。

c. 细胞状态

细胞状态可以被看作是整个LSTM网络的核心,它可以存储和传递信息,同时也能够控制信息的流动和更新。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第14张图片

LSTM的细胞状态会被更新和传递到下一个时间步。在每个时间步 t t t,细胞状态的更新公式如下:

c t = f t ⋅ c t − 1 + i t ⋅ c t ~ c_t = f_t \cdot c_{t-1} + i_t \cdot \tilde{c_t} ct=ftct1+itct~

其中, f t f_t ft是忘记门,表示对细胞状态进行遗忘的权重; i t i_t it是输入门,表示对细胞状态进行更新的权重; c t ~ \tilde{c_t} ct~是当前时间步的候选细胞状态,表示当前时间步的新输入可以对细胞状态产生多少影响。

f t f_t ft接近1时,过去的信息会被完全保留;当 f t f_t ft接近0时,过去的信息会被完全遗忘。当 i t i_t it接近1时,新的输入会被完全保留;当 i t i_t it接近0时,新的输入会被完全忽略。

细胞状态的具体计算过程如下动图。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第15张图片

细胞状态的更新和传递是LSTM中非常重要的过程。在训练过程中,LSTM网络可以通过学习到的权重来自适应地更新细胞状态,保留和传递重要的信息。

d. 输出门

当需要将当前时间步的信息传递到下一层或输出层时,需要通过输出门来控制哪些信息应该被输出。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第16张图片

输出门通过一个sigmoid函数来决定哪些信息需要被输出。在每个时间步 t t t,输出门的计算公式如下:

o t = σ ( W o [ x t , h t − 1 ] + b o ) o_t = \sigma(W_o[x_t, h_{t-1}] + b_o) ot=σ(Wo[xt,ht1]+bo)

其中, W o W_o Wo是输出门的权重矩阵, b o b_o bo是偏置项, x t x_t xt是当前时间步的输入, h t − 1 h_{t-1} ht1是上一时间步的隐藏状态。 σ \sigma σ是sigmoid函数,将输入值映射到0到1之间的概率值。

输出门的输出 o t o_t ot是一个0到1之间的值,表示哪些信息应该被输出。当 o t o_t ot接近1时,所有的信息都会被完全保留;当 o t o_t ot接近0时,所有的信息都会被完全屏蔽。

接下来,LSTM会将细胞状态 c t c_t ct通过一个tanh函数进行处理,得到当前时间步的隐藏状态 h t h_t ht

h t = o t ⋅ tanh ⁡ ( c t ) h_t = o_t \cdot \tanh(c_t) ht=ottanh(ct)

其中, tanh ⁡ \tanh tanh是双曲正切函数,将输入值映射到-1到1之间的值。

输出门的具体计算过程如下动图。

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第17张图片

通过输出门,LSTM能够根据当前时间步的细胞状态和隐藏状态,自适应地控制信息的输出。这样,在LSTM网络中,重要的信息能够被自动地筛选出来,并传递到下一层或输出层。输出门的作用是控制当前时间步的隐藏状态 h t h_t ht中哪些信息应该被输出,从而提高LSTM网络的准确率和效果。

2. 基于LSTM解决时间序列预测问题

(1) 数据集描述

列名 数据格式 含义
日期 date 具体时间
浏览量 int 用户在电商平台页面上查看的次数
访客数 int 电商平台页面的访问者数量
人均浏览量 float 一天内用户平均在电商平台页面上查看的次数
平均停留时间 float 访问者浏览页面所花费的平均时长
跳失率 float 用户通过相应入口进入,只访问了一个页面就离开的访问次数占该页面总访问次数的比例
成交客户数 int 成功付款的客户数
成交单量 int 成功付款的订单数量
成交金额 int 成功付款的总金额
客单价 float 每个用户平均购买商品的金额
成交商品件数 int 成功付款的商品件数
下单客户数 int 已下订单的客户数
下单单量 int 已下订单的订单数量
下单金额 int 已下订单的总金额
下单商品件数 int 已下订单的商品件数

(2) 配置文件

a. 数据集参数
  • feature_columns : csv数据集中用作特征的列,列的编号为0,1,2,···
  • label_columns : csv数据集中用作标签的列,列的编号为0,1,2,···
  • predict_day : 预测未来多少天
b. 网络参数
  • input_size : 输入层尺寸,即用作特征的列的个数
  • output_size : 输出层尺寸,即用作标签的列的个数
  • hidden_size : 隐藏层尺寸
  • lstm_layers : LSTM的层数
  • dropout_rate : Dropout的概率
  • time_step : LSTM中的time_step,即用前多少天的数据来预测后一天
c. 训练参数
  • do_train : 是否训练模型
  • do_predict : 模型是否用作预测
  • add_train : 是否在已训练好的权重上继续训练
  • shuffle_train_data : 是否随机打乱训练数据
  • use_cuda : 是否使用GPU训练
  • train_data_rate : 训练数据占总体数据比例
  • valid_data_rate : 验证数据占训练数据比例
  • batch_size : 单次传递给模型用以训练的样本个数
  • learning_rate : 学习率
  • epoch : 模型训练次数
  • patience : 训练多少epoch,验证集没提升就停掉
  • random_seed : 随机种子,保证可复现
  • do_continue_train : 每次训练把上一次的final_state作为下一次的init_state
d. 训练模式
  • debug_mode : 调试模式下,是为了跑通代码,追求快
  • debug_num : 仅用debug_num条数据来调试
e. 路径参数
  • train_data_path : 数据集保存位置
  • model_save_path : 模型权重保存位置
  • figure_save_path : 预测结果图片保存位置
  • log_save_path : 训练记录保存位置
  • do_log_print_to_screen : 是否将config和训练过程在屏幕显示
  • do_log_save_to_file : 是否将config和训练过程记录到log
  • do_figure_save : 是否保存预测结果图片
  • do_train_visualized : 训练loss可视化,pytorch用visdom

(3) 运行结果展示

  • 下单商品件数的预测

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第18张图片

  • 成交金额的预测

LSTM 详解及 LSTM 解决时间序列预测问题(附代码)_第19张图片

3. 参考资料

[1] 如何从RNN起步,一步一步通俗理解LSTM
[2] LSTM详解
[3] LSTM 简介

你可能感兴趣的:(智能算法,lstm,深度学习,人工智能,算法,python)