运行 informer的主文件
import argparse
import os
import torch
from exp.exp_informer import Exp_Informer
parser.add_argument的这些
参数名称 | 参数描述 |
model | 实验模型。可以设置为informer、informerstack、informerlight(TBD) |
data | 数据集名称 |
root_path | 数据文件的根路径(默认为./data/ETT/) |
data_path | 数据文件名称(默认为ETTh1.csv) |
features | 预测任务(默认为M)。可以设置为M、S、MS (M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量) |
target | S或MS任务中的目标特征(默认为OT) |
freq | 时间特征编码的频率(默认为h) 可以设置为s(秒)、t(分钟)、h(小时)、d(日)、b(工作日)、w(周)、m(月)。也可以使用更详细的频率,如15min或3h |
checkpoints | 模型检查点的位置(默认为./checkpoints/) |
seq_len | Informer编码器的输入序列长度(默认为96) |
label_len | Informer解码器的起始标记长度(默认为48) |
pred_len | 预测序列长度(默认为24) |
enc_in | 编码器输入大小(默认为7) |
dec_in | 解码器输入大小(默认为7) |
c_out | 输出大小(默认为7) |
d_model | 模型的维度(默认为512) |
n_heads | 头的数量(默认为8) |
e_layers | 编码器层的数量(默认为2) |
d_layers | 解码器层的数量(默认为1) |
s_layers | 堆叠编码器层的数量(默认为3,2,1) |
d_ff | fcn的维度(默认为2048) |
factor | Probsparse attn因子(默认为5) |
padding | 填充类型(默认为0) |
distil | 是否在编码器中使用提炼,使用此参数表示不使用提炼(默认为True) |
dropout | 丢弃的概率(默认为0.05) |
attn | 编码器中使用的注意力(默认为prob)。可以设置为prob(informer)、full(transformer) |
embed | 时间特征的编码(默认为timeF)。可以设置为timeF、fixed、learned |
activation | 激活函数(默认为gelu) |
output_attention | 是否在编码器中输出注意力,使用此参数表示输出注意力(默认为False) |
do_predict | 是否预测未见的未来数据,使用此参数表示进行预测(默认为False) |
mix | 是否在生成解码器中使用混合注意力,使用此参数表示不使用混合注意力(默认为True) |
cols | 数据文件中作为输入特征的某些列 |
num_workers | Data loader的工作数(默认为0) |
itr | 实验次数(默认为2) |
train_epochs | 训练周期(默认为6) |
batch_size | 训练输入数据的批量大小(默认为32) |
patience | 提前停止的耐心(默认为3) |
learning_rate | 优化器学习率(默认为0.0001) |
des | 实验描述(默认为test) |
loss | 损失函数(默认为mse) |
lradj | 调整学习率的方式(默认为type1) |
use_amp | 是否使用自动混合精度训练,使用此参数表示使用amp(默认为False) |
inverse | 是否反转输出数据,使用此参数表示反转输出数据(默认为False) |
use_gpu | 是否使用gpu(默认为True) |
gpu | 用于训练和推理的gpu编号(默认为0) |
use_multi_gpu | 是否使用多个gpu,使用此参数表示使用多个gpu(默认为False) |
devices | 多个gpu的设备ID(默认为0,1,2,3) |
args = parser.parse_args()
#解析命令行参数
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
#检查是否可以使用 GPU
if args.use_gpu and args.use_multi_gpu:
args.devices = args.devices.replace(' ','')
device_ids = args.devices.split(',')
args.device_ids = [int(id_) for id_ in device_ids]
args.gpu = args.device_ids[0]
#如果启用了 GPU 且设置了多 GPU 使用,代码会解析 GPU 设备 ID,并准备相应的 GPU 设置。
data_parser = {
'ETTh1':{'data':'ETTh1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
'ETTh2':{'data':'ETTh2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
'ETTm1':{'data':'ETTm1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
'ETTm2':{'data':'ETTm2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
'WTH':{'data':'WTH.csv','T':'WetBulbCelsius','M':[12,12,12],'S':[1,1,1],'MS':[12,12,1]},
'ECL':{'data':'ECL.csv','T':'MT_320','M':[321,321,321],'S':[1,1,1],'MS':[321,321,1]},
'Solar':{'data':'solar_AL.csv','T':'POWER_136','M':[137,137,137],'S':[1,1,1],'MS':[137,137,1]},
}
'''
这是一个数据解析器字典,包含不同数据集的配置信息
如文件名,目标列,输入输出目标维度
(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)
'''
if args.data in data_parser.keys():
data_info = data_parser[args.data]
args.data_path = data_info['data']
args.target = data_info['T']
args.enc_in, args.dec_in, args.c_out = data_info[args.features]
'''
检查输入的数据集是否在 data_parser 中定义,如果是,则从字典中获取相应的配置。
数据路径、目标列、encoder输入、decoder输入、decoder输出的维度
'''
args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ','').split(',')]
#解析并设置堆叠层的数量
args.detail_freq = args.freq
#将频率详细信息保存在另一个参数中
args.freq = args.freq[-1:]
#频率的最后一个元素(一般是h,s,m这些)
print('Args in experiment:')
print(args)
2.4
Exp = Exp_Informer
for ii in range(args.itr):
# 对于每次迭代,根据实验参数设置进行训练和测试。
setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_mx{}_{}_{}'.format(args.model, args.data, args.features,
args.seq_len, args.label_len, args.pred_len,
args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.factor,
args.embed, args.distil, args.mix, args.des, ii)
exp = Exp(args) # 使用给定参数实例化实验对象
print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
exp.train(setting)
print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
exp.test(setting)
#分别对模型进行训练和测试
if args.do_predict:
print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
exp.predict(setting, True)
#如果设置为进行预测,那么执行预测
torch.cuda.empty_cache()