【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第1张图片

一、本文介绍

大家好,最近在搞论文所以在研究各种论文的思想,这篇文章给大家带来的是TiDE模型由Goggle在2023.8年发布,其主要的核心思想是:基于多层感知机(MLP)构建的编码器-解码器架构,核心创新在于它结合了线性模型的简洁性和速度优势,同时能有效处理协变量和非线性依赖。论文中号称TiDE在长期时间序列预测基准测试中不仅表现匹敌甚至超越了先前的方法,而且在速度上比最好的基于Transformer的模型快5到10倍。在官方的开源代码中是并没有预测未来数据功能的,因为这种都是学术文章发表论文的时候只看测试集表现。我在自己的框架下给其补上了这一功能同时加上了绘图的功能,非常适合大家发表论文的适合拿来做对比模型。

(开始之前给大家推荐一下我的专栏,本专栏包含时间序列领域各种模型适合各类人群,同时包含本人创新的框架和模型,无论你是想发论文还是工程项目中使用本专栏都能够满足你的需求) 

   专栏目录:时间序列预测目录:深度学习、机器学习、融合模型、创新模型实战案例

专栏: 时间序列预测专栏:基础知识+数据分析+机器学习+深度学习+Transformer+创新模型

预测未知数据功能如下->

下面的图片集成在模型的预测功能上,大家运行直接可以生成。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第2张图片

滚动长期预测功能如下->

下面的文件是由代码自动生成的csv文件可以用其进行各种可视化操作。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第3张图片

目录

一、本文介绍

二、TiDE的框架原理

2.1 TiDE的框架原理

2.1.1 密集多层感知机(MLP)编码器

2.1.2 密集MLP解码器

2.1.3 特征投影步骤

2.1.4 时序解码器

2.1.5 全局线性残差连接

2.2 TiDE的实验结果

三、实战所用数据集 

四、实战代码 

4.1 个人完善版本下载地址 

4.2 参数详解 

五、DiTE实战

5.1 训练模型

5.2 测试集表现 

5.3 预测未来一天的数据(结果可视化)

5.4 滚动长期预测(结果可视化 + CSV文件生成)

六、如何训练你自己的数据集

七、本文总结 


二、TiDE的框架原理

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第4张图片

论文地址: 官方论文地址

代码地址: 官方代码地址

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第5张图片


2.1 TiDE的框架原理

TiDE(时间序列密集编码器)模型是一个基于多层感知机(MLP)的编码器-解码器架构,旨在简化长期时间序列预测。该模型结合了线性模型的简单性和速度,同时能够有效处理协变量和非线性依赖。理论上,该模型的简单线性类似物能够在特定条件下为线性动态系统(LDS)实现接近最优的误差率。

TiDE(时间序列密集编码器)模型的网络结构可以概括为以下几个关键组成部分:

  1. 密集多层感知机(MLP)编码器:TiDE使用密集的MLP来编码时间序列的过去信息以及协变量。

  2. 密集MLP解码器:解码器同样基于密集的MLP,用于处理编码后的时间序列和未来的协变量。

  3. 特征投影步骤:模型在编码和解码过程中包含一个将动态协变量映射到低维空间的特征投影步骤。

  4. 时序解码器:最终的预测是通过结合每个时间步的解码向量与该时间步的投影特征来形成的。

  5. 全局线性残差连接:从回溯到预测范围,模型还增加了一个全局线性残差连接。

总结:TiDE模型的结构注重于简化和效率,避免了自注意力、递归或卷积机制,从而在处理长期时间序列预测任务时实现了线性的计算量扩展。

下面的图片是TiDE的网络结构图(附上我个人的理解)->

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第6张图片

这个网络结构从输入到输出的工作流程是:

  1. 动态协变量:它首先对时间序列的动态协变量进行特征投影,简化特征的维度。

  2. 查回和属性输入:将过去的时间序列数据(查回部分)与相关属性(如日期、假日等)结合。

  3. 编码:这些结合后的数据被送入密集编码器,它使用多层感知机对信息进行编码,生成一个内部表示。

  4. 解码:内部表示随后被送入密集解码器,再次使用多层感知机,解码预测未来的时间序列。

  5. 时序解码:每个时间步的预测通过时序解码器进行优化,以生成最终的时间序列预测。

  6. 残差连接:为了加强模型的预测能力并减少训练中的问题,如梯度消失,一个残差连接直接将输入的查回部分连接到输出端。

这整个过程就像是将时间序列的历史数据和相关信息通过一个多层处理过程,最终转化为对未来的精准预测。

下面我来分别介绍TiDE的几个关键组成部分(同时每一个步骤在TiDE的网络结构中位置我都进行了标注)->


2.1.1 密集多层感知机(MLP)编码器

密集多层感知机(MLP)编码器是TiDE模型中的核心部分,它的作用是将时间序列的历史数据(也称为查回部分)和协变量(如日期、天气等可能影响预测的外部信息)转换成内部表示。这个编码器通过一系列层次化的网络层(即MLP层)来处理输入数据,每一层都会对数据进行转换和学习,从而捕捉时间序列的复杂模式和依赖关系。简而言之,这个编码器将原始输入转换成模型可以进一步处理的压缩信息。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第7张图片


2.1.2 密集MLP解码器

密集MLP解码器在TiDE模型中扮演着将编码后的信息转换回时间序列预测的角色。这个解码器接收来自编码器的内部表示,并开始构建对未来时间点的预测。它也使用多层感知机层,这些层专门训练用于识别编码信息中的模式,并将这些模式映射到未来的协变量上,从而生成对未来时间序列的预测值。简言之,解码器的任务是解读编码的数据,并将其转换为具体的未来预测。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第8张图片


2.1.3 特征投影步骤

特征投影步骤是TiDE模型的一部分,它负责处理动态协变量——这些是随时间变化而变化的输入变量,如天气或节假日等。在这一步骤中,模型将这些协变量从高维空间降维到一个低维空间。这个过程有助于简化模型处理的信息量,减少计算复杂性,并且可能还能帮助模型更有效地从这些协变量中提取有用的特征,以便进行准确的时间序列预测。简单来说,特征投影就像是对输入数据进行压缩,以便于编码器和解码器更有效地处理。 

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第9张图片


2.1.4 时序解码器

时序解码器是TiDE模型中的一个重要组件,用于处理密集解码器输出的信息。它特别针对每个时间步骤进行工作,把解码器生成的预测转化为最终的时间序列输出。时序解码器通过在每个时间点上应用特定的变换,优化了预测的时间依赖性,增强了模型对时间序列数据中时间动态的捕捉能力。简而言之,时序解码器将解码过程与时间维度相结合,生成精确的逐步预测。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第10张图片


2.1.5 全局线性残差连接

全局线性残差连接是一种在神经网络中常见的技术,用于改善深层网络的学习效率和减少训练难度。在TiDE模型中,残差连接允许从网络的早期层(在本例中是查回部分)直接传递信息到后面的层,这有助于模型在学习过程中保留原始输入数据的信息。这样,即使在网络很深的情况下,也可以缓解梯度消失的问题,确保网络能够有效地学习和适应训练数据。简单来说,它就像是一个快捷通道,使得输入数据可以绕过多个中间层直接影响输出。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第11张图片


2.2 TiDE的实验结果

下面这个表格总结TiDE的效果对比,可以看出在绝大多数的数据上其都有一个最好的结果(但是真假性不易得知,我个人实验效果只能说和之前的一些模型持平吧)。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第12张图片

这个直方图代表了DiTE模型的训练速度(单位是秒S),我个人训练在3070上速度还可以吧毕竟不像图像领域。【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第13张图片


三、实战所用数据集 

了解时序领域的读者都知道,最出名的数据集就是ETTh系列的数据集,我使用的就是ETTh1.csv文件。 

本文我们用到的数据集是ETTh1.csv,该数据集是一个用于时间序列预测的电力负荷数据集,它是 ETTh 数据集系列中的一个。ETTh 数据集系列通常用于测试和评估时间序列预测模型。以下是 ETTh1.csv 数据集的一些内容:

数据内容:该数据集通常包含有关电力系统的多种变量,如电力负荷、价格、天气情况等。这些变量可以用于预测未来的电力需求或价格。

时间范围和分辨率:数据通常按小时或天记录,涵盖了数月或数年的时间跨度。具体的时间范围和分辨率可能会根据数据集的版本而异。 

以下是该数据集的部分截图->


四、实战代码 

4.1 个人完善版本下载地址 

前面的代码我提供了一个官方版本,但是那个很简陋,所以我将其集成了在我的框架下,同时上传到了CSDN中地址如下->

下载地址:【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)


4.2 参数详解 

其中主要的配置文件是main.py文件(我们只需要配置好该文件就可以运行该代码),这里把我设置贴出来方便大家进行设置,同时其中的参数我后面也会进行讲解。

import argparse

import torch

from exp.exp_informer import Exp_Informer

parser = argparse.ArgumentParser(description='TiDE Long Sequences Forecasting')

parser.add_argument('--model', type=str, default='TiDE',
                    help='model of experiment, options: [TiDE]')
parser.add_argument('--data', type=str, default='custom', help='data')
parser.add_argument('--root_path', type=str, default='./', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--is_rolling_predict', type=bool, default=False, help='rolling predict')
parser.add_argument('--rolling_data_path', type=str, default='ETTh1-Test.csv', help='data file')
parser.add_argument('--features', type=str, default='M',
                    help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
                    help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
parser.add_argument('--seq_len', type=int, default=126, help='input sequence length of Informer encoder')
parser.add_argument('--label_len', type=int, default=64, help='start token length of Informer decoder')
parser.add_argument('--pred_len', type=int, default=24, help='prediction sequence length')

# Informer decoder input: concat[start token series(label_len), zero padding series(pred_len)]
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--s_layers', type=str, default='3,2,1', help='num of stack encoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--factor', type=int, default=5, help='probsparse attn factor')
parser.add_argument('--padding', type=int, default=0, help='padding type')
parser.add_argument('--distil', action='store_false',
                    help='whether to use distilling in encoder, using this argument means not using distilling',
                    default=True)

parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
parser.add_argument('--attn', type=str, default='prob', help='attention used in encoder, optio---ns:[prob, full]')
parser.add_argument('--embed', type=str, default='timeF',
                    help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--do_predict', action='store_true', default=True, help='whether to predict unseen future data')
parser.add_argument('--mix', action='store_false', help='use mix attention in generative decoder', default=True)
parser.add_argument('--cols', type=str, nargs='+', help='certain cols from the data files as the input features')
parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=20, help='train epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--loss', type=str, default='mse', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)

parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')

args = parser.parse_args()
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False

args.task_name = 'long_term_forecast'
if args.use_gpu and args.use_multi_gpu:
    args.devices = args.devices.replace(' ', '')
    device_ids = args.devices.split(',')
    args.device_ids = [int(id_) for id_ in device_ids]
    args.gpu = args.device_ids[0]

data_parser = {
    'ETTh1': {'data': 'sum.csv', 'T': 'sl', 'B': [7, 7, 7], 'S': [350, 168, 4], 'MS': [7, 7, 1]},
    'ETTh2': {'data': 'ETTh2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTm1': {'data': 'sum.csv', 'T': 'sl', 'B': [7, 7, 7], 'S': [126, 42, 4], 'MS': [7, 7, 1]},
    'ETTm2': {'data': 'ETTm2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'WTH': {'data': 'WTH.csv', 'T': 'WetBulbCelsius', 'M': [12, 12, 12], 'S': [1, 1, 1], 'MS': [12, 12, 1]},
    'ECL': {'data': 'ECL.csv', 'T': 'MT_320', 'M': [321, 321, 321], 'S': [1, 1, 1], 'MS': [321, 321, 1]},
    'Solar': {'data': 'solar_AL.csv', 'T': 'POWER_136', 'M': [137, 137, 137], 'S': [1, 1, 1], 'MS': [137, 137, 1]},
    'custom': {'data': '{}'.format(args.data_path), 'T': '{}'.format(args.target), '{}'.format(args.features):
        [args.enc_in, args.dec_in, args.c_out]},
}


if args.data in data_parser.keys():
    data_info = data_parser[args.data]
    args.data_path = data_info['data']
    args.target = data_info['T']
    args.enc_in, args.dec_in, args.c_out = data_info[args.features]

args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ', '').split(',')]
args.detail_freq = args.freq
args.freq = args.freq[-1:]

print('Args in experiment:')
print(args)

Exp = Exp_Informer

for ii in range(args.itr):
    # setting record of experiments
    setting = 'group_id{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_mx{}_{}_{}'.format(
        args.data_path, args.model, args.data, args.features,
        args.seq_len, args.label_len, args.pred_len,
        args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.factor,
        args.embed, args.distil, args.mix, args.des, ii)

    exp = Exp(args)  # set experiments
    print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
    exp.train(setting)

    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    exp.test(setting)

    if args.do_predict:
        print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        exp.predict(args, setting, True)

    torch.cuda.empty_cache()

其中涉及到很多的参数,我来进行讲解(大家如果想运行看完参数讲解你是一定能够运行的)

其中的参数涉及到运行成功与否的我都用颜色标出来了 

参数名称 参数类型 参数讲解
0 model str 这是一个用于实验的参数设置,其中包含了三个选项: informer, informerstack, informerlight。根据实验需求,可以选择其中之一来进行实验,默认是使用informer模型。
1 data str 数据,这个并不是你理解的你的数据集文件,而是你想要用官方定义的方法还是你自己的数据集进行定义数据加载器,如果是自己的数据集就输入custom
2 root_path str 这个才是你文件的路径,不要到具体的文件,到目录级别即可。
3 data_path str 这个填写你文件的名称。
4 features str 这个是特征有三个选项M,MS,S。分别是多元预测多元,多元预测单元,单元预测单元。
5 target str 这个是你数据集中你想要预测那一列数据,假设我预测的特征是OT列就输入OT即可。
6 freq str 时间的间隔,你数据集每一条数据之间的时间间隔。
7 checkpoints str 训练出来的模型保存路径
8 seq_len int 用过去的多少条数据来预测未来的数据
9 label_len int 可以裂解为更高的权重占比的部分要小于seq_len
10 pred_len int 预测未来多少个时间点的数据
11 enc_in int 你数据有多少列,要减去时间那一列,这里我是输入8列数据但是有一列是时间所以就填写7
12 dec_in int 同上
13 c_out int 这里有一些不同如果你的features填写的是M那么和上面就一样,如果填写的MS那么这里要输入1因为你的输出只有一列数据。
14 d_model int 用于设置模型的维度,默认值为512。可以根据需要调整该参数的数值来改变模型的维度
15 n_heads int 用于设置模型中的注意力头数。默认值为8,表示模型会使用8个注意力头,我建议和的输入数据的总体保持一致,列如我输入的是8列数据不用刨去时间的那一列就输入8即可。
16 e_layers int 用于设置编码器的层数
17 d_layers int 用于设置解码器的层数
18 s_layers str 用于设置堆叠编码器的层数
19 d_ff int 模型中全连接网络(FCN)的维度,默认值为2048
20 factor int  ProbSparse自注意力中的因子,默认值为5
21 padding int 填充类型,默认值为0,这个应该大家都理解,如果不够数据就填写0.
22 distil bool 是否在编码器中使用蒸馏操作。使用--distil参数表示不使用蒸馏操作,默认为True也是我们的论文中比较重要的一个改进。
23 dropout float 这个应该都理解不说了,丢弃的概率,防止过拟合的。
24 attn str 编码器中使用的注意力类型,默认为"prob"我们论文的主要改进点,提出的注意力机制。
25 embed str 时间特征的编码方式,默认为"timeF"
26 activation str 激活函数
27 output_attention bool 是否在编码器中输出注意力,默认为False
28 do_predict bool 是否进行预测,这里模型中没有给添加算是一个小bug我们需要填写一个default=True在其中。
29 mix bool 在生成式解码器中是否使用混合注意力,默认为True
30 cols str 从数据文件中选择特定的列作为输入特征,应该用不到
31 num_workers int 线程windows大家最好设置成0否则会报线程错误,linux系统随便设置。
32 itr int 实验运行的次数,默认为2,我们这里改成数字1.
33 train_epochs int 训练的次数
34 batch_size int 一次往模型力输入多少条数据
35 patience int 早停机制,如果损失多少个epochs没有改变就停止训练。
36 learning_rate float 学习率。
37 des str         实验描述,默认为"test"
38 loss str      损失函数,默认为"mse"
39 lradj str      学习率的调整方式,默认为"type1"
40 use_amp bool 混合精度训练,
41 inverse bool 我们的数据输入之前会被进行归一化处理,这里默认为False,算是一个小bug因为输出的数据模型没有给我们转化成我们的数据,我们要改成True。
42 use_gpu bool 是否使用GPU训练,根据自身来选择
43 gpu int GPU的编号
44 use_multi_gpu bool 是否使用多个GPU训练。
45 devices str GPU的编号


五、DiTE实战

5.1 训练模型

当我们配置完所有的参数之后直接运行main.py文件就可以开始训练了,控制台的输出如下-> 

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第14张图片

训练完的结果保存在该文件目录下(其中的pth文件就是我们的模型我们之后预测直接加载其就可以)-> 

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第15张图片


5.2 测试集表现 

测试集本来想做一个画图功能了后来想着大家可能也不需要,所以就没做测试集会打印出模型误差如下->


5.3 预测未来一天的数据(结果可视化)

 下面的文件是代码自动生成预测一天的csv文件。

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第16张图片

可视化如下->

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第17张图片


5.4 滚动长期预测(结果可视化 + CSV文件生成)

可以看出这个效果只能说一般,但是我们研究这种最新的模型只要看的是其思想~

【Google2023】利用TiDE进行长期预测实战(时间序列密集编码器)_第18张图片


六、如何训练你自己的数据集

上面介绍了用我的数据集训练模型,那么大家在利用模型的时候如何训练自己的数据集呢这里给家介绍一下需要修改的几处地方。

parser.add_argument('--data', type=str, default='custom', help='data')
parser.add_argument('--root_path', type=str, default='', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--features', type=str, default='MS',
                    help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',

首先需要修改的就是上面这几处,

  • 其中data必须填写custom,
  • root_path填写文件夹即可,
  • data_path填写具体的文件在你文件夹下面,
  • features前面有讲解,具体是看你自己的数据集,我这里MS就是7列结果综合分析输出想要的那一列结果的预测值,
  • target就是你数据集中你想要知道那列的预测值的列名,
  • freq就是你两条数据之间的时间间隔。
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length of Informer encoder')
parser.add_argument('--label_len', type=int, default=48, help='start token length of Informer decoder')
parser.add_argument('--pred_len', type=int, default=24, help='prediction sequence length')

然后这三个就是影响精度的地方,seq_len和label_len需要根据数据的特性来设置,要进行专业的数据分析,我会在下一周出教程希望到时候能够帮助到大家。

parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')

这三个参数要修改和你的数据集对应和前面features的设定来配合设置,具体可以看我前面的参数讲解部分,参数需要修改的就这些,然后是代码部分如下。

data_parser = {
    'ETTh1': {'data': 'ETTh1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTh2': {'data': 'ETTh2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTm1': {'data': 'ETTm1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'ETTm2': {'data': 'ETTm2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
    'WTH': {'data': 'WTH.csv', 'T': 'WetBulbCelsius', 'M': [12, 12, 12], 'S': [1, 1, 1], 'MS': [12, 12, 1]},
    'ECL': {'data': 'ECL.csv', 'T': 'MT_320', 'M': [321, 321, 321], 'S': [1, 1, 1], 'MS': [321, 321, 1]},
    'Solar': {'data': 'solar_AL.csv', 'T': 'POWER_136', 'M': [137, 137, 137], 'S': [1, 1, 1], 'MS': [137, 137, 1]},
    'custom': {'data': 'ETTh1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
}

main_informer.py文件有如上的结构,这是我修改之后的,你可以按照我的修改,其中custom就是对应你前面设置参数data的名字,然后data后面替换成你的数据集,必须是csv格式的文件这里,然后是T大家不用管,OT修改成你自己数据集中预测的哪一列列名,就是前面设置的target值,然后是M,S,MS分别对应你数据中的列的给个数即可,我这里输入是8列扣去时间一列在M中就全部填写7即可,S的话我的数据集用不到,MS就是7列输出一列。 

最后呢大家如果需要我的数据集和修改完成之后的实战代码可以在评论区留言。


七、本文总结 

到此本文的正式分享内容就结束了,在这里给大家推荐我的时间序列专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的模型进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~)如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

 专栏回顾: 时间序列预测专栏——持续复习各种顶会内容——科研必备

如果大家有不懂的也可以评论区留言一些报错什么的大家可以讨论讨论看到我也会给大家解答如何解决!最后希望大家工作顺利学业有成!

你可能感兴趣的:(时间序列预测专栏,人工智能,深度学习,python,时间序列预测,transformer,算法)