TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting

tft

  • 1 模型简述
    • 1.1 输入
    • 1.2 输出
  • 2 损失函数
  • 3 模型结构
    • 3.1 基本结构
    • 3.2 整体结构
  • 4 模型参数说明
    • 4.1 所有参数
    • 4.2 列定义参数
    • 4.3 处理后输入
  • 5 TFT贡献
  • 6 其他总结
  • 参考资料

1 模型简述

tft模型具有下面特征:

  • 支持多个时间序列
  • 基于注意力的模型结构
  • 具有可解释性
  • 特征选择,并使用门控进行特征压缩,速度快

1.1 输入

输入数据为df格式,列可分为下面六类

  • target:预测目标值
  • observed inputs:观测输入,比如上一时刻的值等,无法提前知道的
  • known inputs:已知输入,比如年月日,节假日等,可以提前知道的
  • static input:静态输入,比如商店的地址等,不会变化的
  • id:时间序列编号,不作为模型输入,只作为索引
  • time:时间索引,不作为模型输入,只作为索引

1.2 输出

各分位数的预测值:
比如:quantiles = [0.1, 0.5, 0.9]
模型就会给出0.1,0.5,0.9分位数的预测值

2 损失函数

分位数损失函数
TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第1张图片

3 模型结构

TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第2张图片
说明:

  • Variable Selection is used for judicious selection of the most salient features based on the input. ----特征选择可以从输入中选择更显著的特征
  • Gated Residual Network blocks enable efficient information flow with skip connections and gating layers. ----GRN中的跳跃连接和门控层可以使信息流通更有效率,模型训练速度更快
  • Time-dependent processing is based on LSTMs for local processing, and multi-head attention for integrating information from any time step.

3.1 基本结构

TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第3张图片

TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第4张图片TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第5张图片TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第6张图片TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第7张图片
TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第8张图片

3.2 整体结构

TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第9张图片

4 模型参数说明

4.1 所有参数

  • dropout_rate: 0.1 ----训练时样本随机丢弃率,防止过拟合
  • hidden_layer_size: 5 ----隐藏层大小
  • learning_rate: 0.001, ----学习率
  • max_gradient_norm: 0.01, ----adam中的梯度剪裁方法,clipnorm=max_gradient_norm,即剪裁后的标准化值不大于clipnorm值。
  • minibatch_size: 64, ----批次大小
  • model_folder: ./outputs/, ----模型保存路径
  • num_heads: 4, ----transformer 中注意力头的个数
  • stack_size: 1,----自注意层堆叠层数,默认时1,代码中没用到,设置了也不起作用
  • total_time_steps: 192, ---- 编码器长度+解码器长度
  • num_encoder_steps: 168, ----编码器长度
  • num_epochs: 1, ----训练轮数
  • early_stopping_patience: 10, ----早停等待次数
  • multiprocessing_workers: 5,----多进程数量
  • column_definition: [(id, 0, 4), # 列定义
    (hours_from_start, 0, 5),
    (power_usage, 0, 0),
    (hour, 0, 2),
    (day_of_week, 0, 2),
    (hours_from_start, 0, 2),
    (categorical_id, 1, 3)],
  • input_size: 5, ----输入维度
  • output_size: 1,----输出维度
  • category_counts: [369],----所有类别信息,类别序号是从0开始
  • input_obs_loc: [0], ----观察输入,观测输入不能是静态输入:static_input_loc
  • static_input_loc: [4], ----静态输入,值不改变的,category
  • known_regular_inputs: [1, 2, 3], ----已知的非类别输入,已知输入不能是观察输入:input_obs_loc
  • known_categorical_inputs: [0] ----已知的类别输入,类别序号是从0开始

4.2 列定义参数

column_definition 输入格式为(列名,数据类型, 输入类型)],

  • 其中数据类型 有三类
    0 - 实数
    1 - 类别
    2 - 日期
  • 输入类型 有六类
    0- TARGET 目标列
    1 - OBSERVED_INPUT 观测输入
    2 - KNOWN_INPUT 已知输入
    3 - STATIC_INPUT 静态输入
    4 - ID 用来识别时间序列的编号
    5-TIME 时间索引
  • input_size = 所有列 - 2 (id, time),因为id和time不作为输入,只作为索引。
  • 剩余的所有输入列可分为两类,categorical类和regular类:
  • num_categorical_variables = len(self.category_counts)
  • num_regular_variables = self.input_size - num_categorical_variables
  • regular_inputs = all_inputs[:, :, :num_regular_variables] ----所有regular都会convert_real_to_embedding
  • categorical_inputs = all_inputs[:, :, num_regular_variables:] ----所有category都会进行embedding

4.3 处理后输入

数据处理中会将所有输入分为四类:

  • obs_inputs, 观测输入
    = input_obs_loc 所有观测的值
  • static_inputs,静态输入
    = regular_inputs中 是静态的 + categorical_inputs中 是静态的
  • known_combined_layer, 已知输入
    = known_regular_inputs 已知中非静态的 + known_categorical_inputs 已知中非静态的
  • unknown_inputs, 未知输入
    = regular_inputs中 非已知 且 非观测的 + categorical_inputs 中 非已知 且 非观测的

5 TFT贡献

  • 门控机制:gate往往和add&norm一起使用。可使模型适用不同的深度和网络复杂度,从而使模型适用于不同的数据和场景。
  • 特征选择网络:可使模型在每一步选择相关的特征。
  • 静态协变量编码:将静态变量编码,作为辅助上下文信息加入到网络中,是其他特征编码的条件。
  • seq2seq层用来捕捉短时依赖关系
  • 多头注意力机制用来捕捉长时依赖关系
  • 用百分位数作为预测区间。

6 其他总结

  1. 门控结构的作用
    ----门控结构的作用:门控结构让模型只关注有用的特征,忽略那些无用的特征,会让模型能够适应不同的深度和复杂度,从而是模型适用于不同的数据集和场景。
  2. GRN有什么作用
    ----GRN有什么作用:信息流通更有效率。因为他使用了skip策略和门控机制。即能捕捉重要的特征。也不会漏掉任何重要信息。
  3. 为什么lstm捕捉的是局部信息? 为什么多头注意力模块捕捉的是全局信息?
    ----lstm是直接处理输入数据的,会处理historical_features, future_features,还要间接处理static特征。最终得到状态temporal_feature_layer。
    ----而多头注意力模块没有直接处理输入的特征数据,而是处理的temporal_feature_layer和静态特征数据。多头注意力模块的输出为decoder。还有相应的注意力self_att。
    ----输出值是如何得到的呢?输出值是将多头注意力模块的输出decoder和lstm的输出temporal_feature_layer一起做add&norm, 然后再做Dense得到的。
    ----因为lstm离输入数据近,受输入数据影响大,所以最可能捕捉的是局部信息。
    ----因为多头注意力模块处理的是lstm的输出状态,离原始输入较远,所以抽象层次更高,受原始输入影响较小,所以捕捉到的最可能是全局性的信息。
    ----所以,输出值是由lstm代表的局据信息+多头注意力模块代表的全局信息共同得到的。二者缺一不可。
    ----通过观察多头注意力模块的输出self_att,我们可以看到,确实捕捉的是类似周期性这种的全局信息。
    ----因为self_att的维度也是固定的,如果有周期性的话,那么self_att中的注意力数据应该也会呈现出周期性,因为i时刻的模式肯定与i-T时刻的模式很像,i时刻对i-T时刻的关注肯定会多一点,有点类似acf自相关函数。
    如下图所示:天周期每24h会有一个峰值,星期周期每7d会有一个峰值。
    TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting_第10张图片
  4. 特征重要性
    ----特征重要性可以通过static_weights,historical_flags,future_flags得到,他们代表特征选择权重。

参考资料

tft文章链接
google-research github实现代码

你可能感兴趣的:(算法实战,算法)