informer 辅助笔记:main_informer.py

运行 informer的主文件

import argparse
import os
import torch

from exp.exp_informer import Exp_Informer

1 参数

parser.add_argument的这些

参数名称 参数描述
model 实验模型。可以设置为informer、informerstack、informerlight(TBD)
data 数据集名称
root_path 数据文件的根路径(默认为./data/ETT/)
data_path 数据文件名称(默认为ETTh1.csv)
features

预测任务(默认为M)。可以设置为M、S、MS

(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)

target S或MS任务中的目标特征(默认为OT)
freq

时间特征编码的频率(默认为h)

可以设置为s(秒)、t(分钟)、h(小时)、d(日)、b(工作日)、w(周)、m(月)。也可以使用更详细的频率,如15min或3h

checkpoints 模型检查点的位置(默认为./checkpoints/)
seq_len Informer编码器的输入序列长度(默认为96)
label_len Informer解码器的起始标记长度(默认为48)
pred_len 预测序列长度(默认为24)
enc_in 编码器输入大小(默认为7)
dec_in 解码器输入大小(默认为7)
c_out 输出大小(默认为7)
d_model 模型的维度(默认为512)
n_heads 头的数量(默认为8)
e_layers 编码器层的数量(默认为2)
d_layers 解码器层的数量(默认为1)
s_layers 堆叠编码器层的数量(默认为3,2,1)
d_ff fcn的维度(默认为2048)
factor Probsparse attn因子(默认为5)
padding 填充类型(默认为0)
distil 是否在编码器中使用提炼,使用此参数表示不使用提炼(默认为True)
dropout 丢弃的概率(默认为0.05)
attn 编码器中使用的注意力(默认为prob)。可以设置为prob(informer)、full(transformer)
embed 时间特征的编码(默认为timeF)。可以设置为timeF、fixed、learned
activation 激活函数(默认为gelu)
output_attention 是否在编码器中输出注意力,使用此参数表示输出注意力(默认为False)
do_predict 是否预测未见的未来数据,使用此参数表示进行预测(默认为False)
mix 是否在生成解码器中使用混合注意力,使用此参数表示不使用混合注意力(默认为True)
cols 数据文件中作为输入特征的某些列
num_workers Data loader的工作数(默认为0)
itr 实验次数(默认为2)
train_epochs 训练周期(默认为6)
batch_size 训练输入数据的批量大小(默认为32)
patience 提前停止的耐心(默认为3)
learning_rate 优化器学习率(默认为0.0001)
des 实验描述(默认为test)
loss 损失函数(默认为mse)
lradj 调整学习率的方式(默认为type1)
use_amp 是否使用自动混合精度训练,使用此参数表示使用amp(默认为False)
inverse 是否反转输出数据,使用此参数表示反转输出数据(默认为False)
use_gpu 是否使用gpu(默认为True)
gpu 用于训练和推理的gpu编号(默认为0)
use_multi_gpu 是否使用多个gpu,使用此参数表示使用多个gpu(默认为False)
devices 多个gpu的设备ID(默认为0,1,2,3)

2 其他部分

2.1 GPU 相关

args = parser.parse_args()
#解析命令行参数

args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
#检查是否可以使用 GPU

if args.use_gpu and args.use_multi_gpu:
    args.devices = args.devices.replace(' ','')
    device_ids = args.devices.split(',')
    args.device_ids = [int(id_) for id_ in device_ids]
    args.gpu = args.device_ids[0]
    #如果启用了 GPU 且设置了多 GPU 使用,代码会解析 GPU 设备 ID,并准备相应的 GPU 设置。


2.2 数据集相关

data_parser = {
    'ETTh1':{'data':'ETTh1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'ETTh2':{'data':'ETTh2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'ETTm1':{'data':'ETTm1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'ETTm2':{'data':'ETTm2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'WTH':{'data':'WTH.csv','T':'WetBulbCelsius','M':[12,12,12],'S':[1,1,1],'MS':[12,12,1]},
    'ECL':{'data':'ECL.csv','T':'MT_320','M':[321,321,321],'S':[1,1,1],'MS':[321,321,1]},
    'Solar':{'data':'solar_AL.csv','T':'POWER_136','M':[137,137,137],'S':[1,1,1],'MS':[137,137,1]},
}
'''
这是一个数据解析器字典,包含不同数据集的配置信息

如文件名,目标列,输入输出目标维度

(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)
'''


if args.data in data_parser.keys():
    data_info = data_parser[args.data]
    args.data_path = data_info['data']
    args.target = data_info['T']
    args.enc_in, args.dec_in, args.c_out = data_info[args.features]
'''
检查输入的数据集是否在 data_parser 中定义,如果是,则从字典中获取相应的配置。

数据路径、目标列、encoder输入、decoder输入、decoder输出的维度
'''

2.3 设置参数


args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ','').split(',')]
#解析并设置堆叠层的数量

args.detail_freq = args.freq
#将频率详细信息保存在另一个参数中

args.freq = args.freq[-1:]
#频率的最后一个元素(一般是h,s,m这些)

print('Args in experiment:')
print(args)

2.4

Exp = Exp_Informer

for ii in range(args.itr):
    # 对于每次迭代,根据实验参数设置进行训练和测试。
    setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_mx{}_{}_{}'.format(args.model, args.data, args.features, 
                args.seq_len, args.label_len, args.pred_len,
                args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.factor, 
                args.embed, args.distil, args.mix, args.des, ii)

    exp = Exp(args) # 使用给定参数实例化实验对象
    print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
    exp.train(setting)
    
    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    exp.test(setting)
    #分别对模型进行训练和测试

    if args.do_predict:
        print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        exp.predict(setting, True)
    #如果设置为进行预测,那么执行预测

    torch.cuda.empty_cache()

你可能感兴趣的:(python库整理,笔记)