informer源码注释详情记录

目录

  • main_informer.py 源码详情注释

main_informer.py 源码详情注释

import argparse
import datetime
import json
import os
import shutil
import sys

import pandas as pd
import torch

from exp.exp_informer import Exp_Informer
from utils.visualization import *
from utils.initialize_random_seed import *
from pyecharts.globals import CurrentConfig, OnlineHostType


def init_parser():
    parser = argparse.ArgumentParser(description='[Informer] Long Sequences Forecasting')
    parser.add_argument('--model', type=str, required=False, default='informer',
                        help='model of experiment, options: [informer, informerstack, informerlight(TBD)]')

    parser.add_argument('--data', type=str, required=False, default='C5', help='data them--------------//')
    parser.add_argument('--root_path', type=str, default='./data/C5/',
                        help='数据文件的根路径(root path of the data file)---------------//')
    # parser.add_argument('--root_path', type=str, default='./data/ETT/', help='数据文件的根路径(root path of the data file)')
    # parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--data_path', type=str, default='C5#21.csv', help='data file--------------//')
    parser.add_argument('--features', type=str, default='MS',
                        help='预测任务选项(forecasting task, options):[M, S, MS]; ------------------//'
                             'M:多变量预测多元(multivariate predict multivariate), '
                             'S:单变量预测单变量(univariate predict univariate), '
                             'MS:多变量预测单变量(multivariate predict univariate)')
    parser.add_argument('--target', type=str, default='NH3', help='S或MS任务中的目标特征列名(target feature in S or MS task)//')
    parser.add_argument('--freq', type=str, default='15t',
                        help='时间特征编码的频率(freq for time features encoding), ---------------------//'
                             '选项(options):[s:secondly, t:minutely, h:hourly, d:daily, b:工作日(business days), w:weekly, m:monthly], '
                             '你也可以使用更详细的频率,比如15分钟或3小时(you can also use more detailed freq like 15t or 3h)')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/',
                        help='模型检查点的位置(location of model checkpoints)')

    # seq_len其实就是n个滑动窗口的大小,pred_len就是一个滑动窗口的大小
    # [24,36,48,96,128]
    parser.add_argument('--seq_len', type=int, default=96,
                        help='Informer编码器的输入序列长度(input sequence length of Informer encoder)原始默认为96++++++++++++')
    parser.add_argument('--label_len', type=int, default=48,
                        help='inform解码器的开始令牌长度(start token length of Informer decoder),原始默认为48++++++++++++++')
    parser.add_argument('--pred_len', type=int, default=76,
                        help='预测序列长度(prediction sequence length)原始默认为24++++++++++++++++++')
    # pred_len就是要预测的序列长度(要预测未来多少个时刻的数据),也就是Decoder中置零的那部分的长度
    # Informer decoder input: concat[start token series(label_len), zero padding series(pred_len)]

    parser.add_argument('--enc_in', type=int, default=5, help='编码器输入大小(encoder input size)')
    parser.add_argument('--dec_in', type=int, default=5, help='解码器输入大小(decoder input size)')
    parser.add_argument('--c_out', type=int, default=1, help='输出尺寸(output size)')

    parser.add_argument('--d_model', type=int, default=16, help='模型维数(dimension of model)默认是512********************')
    parser.add_argument('--n_heads', type=int, default=8, help='(num of heads)***************')
    parser.add_argument('--e_layers', type=int, default=2, help='编码器层数(num of encoder layers)*************')
    parser.add_argument('--d_layers', type=int, default=1, help='解码器层数(num of decoder layers)*******************')
    parser.add_argument('--s_layers', type=str, default='3,2,1',
                        help='堆栈编码器层数(num of stack encoder layers)*****************')
    parser.add_argument('--d_ff', type=int, default=64, help='fcn维度(dimension of fcn),默认是2048****************')

    parser.add_argument('--factor', type=int, default=5, help='probsparse attn factor')
    parser.add_argument('--padding', type=int, default=0, help='padding type')
    parser.add_argument('--distil', action='store_true', help='是否在编码器中使用蒸馏,使用此参数意味着不使用蒸馏'
                                                               '(whether to use distilling in encoder, using this argument means not using distilling)',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout******************')
    parser.add_argument('--seed', type=int, default=12345, help='random seed 随机数种子')
    parser.add_argument('--random_choos', type=bool, default=True, help='random seed 随机数种子')
    parser.add_argument('--attn', type=str, default='prob', help='用于编码器的注意力机制,选项:[prob, full]'
                                                                 '(attention used in encoder, options:[prob, full])')
    """
    enc_in: informer的encoder的输入维度
    dec_in: informer的decoder的输入维度
    c_out: informer的decoder的输出维度
    d_model: informer中self-attention的输入和输出向量维度
    n_heads: multi-head self-attention的head数
    e_layers: informer的encoder的层数
    d_layers: informer的decoder的层数
    d_ff: self-attention后面的FFN的中间向量表征维度
    factor: probsparse attention中设置的因子系数
    padding: decoder的输入中,作为占位的x_token是填0还是填1
    distil: informer的encoder是否使用注意力蒸馏
    attn: informer的encoder和decoder中使用的自注意力机制
    embed: 输入数据的时序编码方式
    activation: informer的encoder和decoder中的大部分激活函数
    output_attention: 是否选择让informer的encoder输出attention以便进行分析
    
    小数据集的预测可以先使用默认参数或适当减小d_model和d_ff的大小

    """
    # 时间特征编码【未知】
    parser.add_argument('--embed', type=str, default='timeF', help='时间特征编码,选项:[timeF, fixed, learned]'
                                                                   '(time features encoding, options:[timeF, fixed, learned])')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--output_attention', action='store_true', help='是否在编码器中输出注意力'
                                                                        '(whether to output attention in ecoder)')
    parser.add_argument('--do_predict', action='store_true', default=True, help='是否预测看不见的未来数据'
                                                                                '(whether to predict unseen future data)*********************')
    parser.add_argument('--mix', action='store_false', help='在生成解码器中使用混合注意力'
                                                            '(use mix attention in generative decoder)', default=True)
    parser.add_argument('--cols', type=str, nargs='+', help='将数据文件中的某些cols作为输入特性'
                                                            '(certain cols from the data files as the input features)')
    parser.add_argument('--num_workers', type=int, default=0, help='工作的数据加载器数量'
                                                                   'data loader num workers')
    parser.add_argument('--itr', type=int, default=5, help='次实验------------'
                                                           'experiments times')
    """
    """

    parser.add_argument('--train_epochs', type=int, default=200, help='train epochs------------')
    parser.add_argument('--batch_size', type=int, default=64, help='训练输入数据的批大小-------------'
                                                                   'batch size of train input data')
    parser.add_argument('--patience', type=int, default=5, help='提前停止的连续轮数'
                                                                'early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='optimizer learning rate*******************')
    parser.add_argument('--des', type=str, default='test', help='实验描述'
                                                                'exp description')

    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--loss_jude', type=str, default='mse', help='loss function')

    parser.add_argument('--lradj', type=str, default='type1', help='校正的学习率'
                                                                   'adjust learning rate')
    parser.add_argument('--use_amp', action='store_true', help='使用自动混合精度训练'
                                                               'use automatic mixed precision training', default=True)
    parser.add_argument('--output', type=str, default='./output', help='输出路径')

    # 想要获得最终预测的话这里应该设置为True;否则将是获得一个标准化的预测。
    parser.add_argument('--inverse', action='store_true', help='逆标准化输出数据'
                                                               'inverse output data', default=True)

    parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
    parser.add_argument('--gpu', type=int, default=0, help='gpu')
    parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
    parser.add_argument('--devices', type=str, default='0,1', help='device ids of multile gpus')

    # 进行parser的变量初始化,获取实例。
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = init_parser()
    # print(torch.cuda.is_available())
    # 判断GPU是否能够使用,并获取标识
    args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False

    # 判断是否使用多块GPU,默认不使用多块GPU
    if args.use_gpu and args.use_multi_gpu:
        # 获取显卡列表,type:str
        args.devices = args.devices.replace(' ', '')
        # 拆分显卡获取列表,type:list
        device_ids = args.devices.split(',')
        # 转换显卡id的数据类型
        args.device_ids = [int(id_) for id_ in device_ids]
        # 获取第一块显卡
        args.gpu = args.device_ids[0]

    # 初始化数据解析器,用于定义训练模式、预测模式、数据粒度的初始化选项。
    """
    字典格式:{数据主题:{data:数据路径,'T':目标字段列名,'M':,'S':,'MS':}}

    'M:多变量预测多元(multivariate predict multivariate)''S:单变量预测单变量(univariate predict univariate)''MS:多变量预测单变量(multivariate predict univariate)'"""
    data_parser = {
        'ETTh1': {'data': 'ETTh1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTh2': {'data': 'ETTh2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTm1': {'data': 'ETTm1.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'ETTm2': {'data': 'ETTm2.csv', 'T': 'OT', 'M': [7, 7, 7], 'S': [1, 1, 1], 'MS': [7, 7, 1]},
        'WTH': {'data': 'WTH.csv', 'T': 'WetBulbCelsius', 'M': [12, 12, 12], 'S': [1, 1, 1], 'MS': [12, 12, 1]},
        'ECL': {'data': 'ECL.csv', 'T': 'MT_320', 'M': [321, 321, 321], 'S': [1, 1, 1], 'MS': [321, 321, 1]},
        'Solar': {'data': 'solar_AL.csv', 'T': 'POWER_136', 'M': [137, 137, 137], 'S': [1, 1, 1], 'MS': [137, 137, 1]},
        'electric_power': {'data': 'electric.csv', 'T': 'spg', 'M': [5, 5, 5], 'S': [1, 1, 1], 'MS': [5, 5, 1]},
        'C5': {'data': 'C5#21.csv', 'T': 'NH3', 'M': [4, 4, 4], 'S': [1, 1, 1], 'MS': [4, 4, 1]},
    }

    # 判断在parser中定义的数据主题是否在解析器中
    if args.data in data_parser.keys():
        # 根据args里面定义的数据主题,获取对应的初始化数据解析器info信息,type:dict
        data_info = data_parser[args.data]
        # 获取该数据主题的数据文件的路径
        args.data_path = data_info['data']
        # 从数据解析器中获取 S或MS任务中的目标特征列名。
        args.target = data_info['T']
        # 从数据解析器中 根据变量features的初始化信息 获取 编码器输入大小,解码器输入大小,输出尺寸
        args.enc_in, args.dec_in, args.c_out = data_info[args.features]

    # 堆栈编码器层数,type:list
    args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ', '').split(',')]
    # 时间特征编码的频率,就是进行特征工程的时候时间粒度选取多少
    args.detail_freq = args.freq
    args.freq = args.freq[-1:]

    print('Args in experiment:')
    print(args.freq)
    print(args)
    now_time = datetime.datetime.now().strftime('%mM_%dD %HH:%Mm:%Ss').replace(" ", "_").replace(":", "_")

    # 获取模型实例
    Exp = Exp_Informer

    # 获取page实例
    page_loss = get_page_loss(args.itr)
    page_pt = get_page_value(args.itr)
    page_p = get_page_noTrue(args.itr)


    # 未来的那段时间的真实值
    def get_true_month(sheet_name):
        data = pd.read_excel('./TrueValue/C5真实值.xls',sheet_name=sheet_name)
        true_month = [round(i, 3) for i in data["NH3"].values.tolist()]
        return true_month

    def get_true_date(sheet_name):
        data = pd.read_excel('./TrueValue/C5真实值.xls',sheet_name=sheet_name)
        true_date = [round(i, 3) for i in data["NH3"].values.tolist()]
        return true_date
    them = "11日"
    # 存储数据的字典,为了将预测和均值和真实值存储到本地,(若是没有真实值,那么不存储真实值)
    data_dict = dict()
    # 存储未来预测值的真实数据,为了做可视化和评估未来
    true = []
    # 存储模型信息的json文件
    info_dict = dict()
    # 存储预测未来的时候生成的时间
    pred_dates = []

    if data_parser["C5"]["data"] == "C5#21.csv":
        # args.batch_size = 32
        them = '11日'
        # args.freq = 'h'
        try:
            true_date = get_true_date(them)
            # args.pred_len = len(true_date)
            true = true_date
            data_dict["true"] = true_date
        except:
            print("提示:由于未来还没有发生,在真实值数据中没有这个月份数据,故而无法画出未来预测值~未来值的对比图!")
        finally:
            print("Program to continue!>>>")
    else:
        # args.batch_size = 32
        them = '11日'
        # args.freq = 'h'
        try:
            true_month = get_true_month(them)
            # args.pred_len = len(true_month)
            true = true_month
            data_dict["true"] = true_month
        except:
            print("提示:由于未来还没有发生,在真实值数据中没有这个月份数据,故而无法画出未来预测值~未来值的对比图!")
        finally:
            print("Program to continue!>>>")

    # 构建单次运行的存储路径:
    run_name_dir_old = them + "_" + args.model + "_" + data_parser["C5"]["data"][
        -4] + "_" + now_time + "_" + args.data
    args.output = os.path.join(args.output, data_parser["C5"]["data"][-4])
    run_name_dir = os.path.join(args.output, run_name_dir_old)
    if not os.path.exists(run_name_dir):
        os.makedirs(run_name_dir)
    # 单次运行的n个实验的模型存储的路径:需要判断是否存在,训练的时候已经判断了
    args.checkpoints = os.path.join(args.checkpoints, data_parser["C5"]["data"][:-4])
    run_name_dir_ckp = os.path.join(args.checkpoints, run_name_dir_old)

    # 存储整个实验的info信息
    info_file = os.path.join(run_name_dir, "{}_info_{}_{}.json".format(them, args.model, args.data))

    is_show_label_sign = data_parser["C5"]["data"][-4]

    df_columns = []
    # 要进行多少次实验,一次实验就是完成一个模型的训练-测试-预测 过程。默认2for ii in range(args.itr):
        run_ex_dir = os.path.join(run_name_dir, "第_{}_次实验记录".format(ii + 1))
        if args.random_choos == True:
            pass
        else:
            setup_seed(args.seed)
        if not os.path.exists(run_ex_dir):
            os.makedirs(run_ex_dir)
        # 添加实验info
        info_dict["实验序号"] = ii
        info_dict["model"] = args.model
        info_dict["data_them"] = args.data
        info_dict["编码器的输入序列长度【滑动窗口大小】"] = args.seq_len
        info_dict["预测序列长度"] = args.pred_len
        info_dict["时间特征编码的频率【数据粒度】freq"] = args.freq
        info_dict["dorpout"] = args.dropout
        info_dict["batch_size"] = args.batch_size
        info_dict["损失函数loss"] = args.loss
        info_dict["提前停止的连续轮数patience"] = args.patience
        info_dict["随机种子"] = args.seed
        info_dict["是否随机选择"] = args.random_choos

        # 实验设置记录要点,方便打印,同时也作为文件名字传入参数,setting record of experiments
        setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_mx{}_{}_{}'.format(args.model,
                                                                                                             args.data,
                                                                                                             args.features,
                                                                                                             args.seq_len,
                                                                                                             args.label_len,
                                                                                                             args.pred_len,
                                                                                                             args.d_model,
                                                                                                             args.n_heads,
                                                                                                             args.e_layers,
                                                                                                             args.d_layers,
                                                                                                             args.d_ff,
                                                                                                             args.attn,
                                                                                                             args.factor,
                                                                                                             args.embed,
                                                                                                             args.distil,
                                                                                                             args.mix,
                                                                                                             args.des,
                                                                                                             ii)
        # 设置实验,将数据参数和模型变量传入实例
        exp = Exp(args)  # set experiments

        # 训练模型
        print('>>>>>>>start training :  {}  >>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
        model, info_dict, all_epoch_train_loss, all_epoch_vali_loss, all_epoch_test_loss, epoch_count = exp.train(
            setting, info_dict, run_name_dir_ckp, run_ex_dir, args)

        # 模型测试
        print('>>>>>>>testing :  {}  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        info_dict, test_pred, test_true = exp.test(setting, info_dict, run_ex_dir, args)
        # print(test_pred)
        # print(test_true)

        future_pred, pred_date = 0, 0
        # 做预测
        if args.do_predict:
            print('>>>>>>>predicting :  {}  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            # 模型预测未来
            future_pred, pred_date = exp.predict(setting, run_name_dir_ckp, run_ex_dir, args, True)
            pred_dates = pred_date
            # print("未来的预测值:",future_pred)
            # print("未来的时间范围:",pred_date)

        # 存储实验的info信息:
        with open(info_file, mode='a', encoding='utf-8') as f:
            json.dump(info_dict, f, indent=4, ensure_ascii=False)
        # 存储数据:
        df = pd.DataFrame(data={"NH3": future_pred, "time": pred_date}, columns=["NH3", "time"])
        df.to_csv(os.path.join(run_ex_dir, "第{}次实验_未来预测结果.csv".format(ii + 1)), index=False, encoding='utf-8')

        # 存储预测结果到字典
        data_dict["实验{}".format(ii + 1)] = future_pred
        # 添加字段名字
        df_columns.append("实验{}".format(ii + 1))
        # 可视化:

        line_p = chart_predict(pred_date, future_pred, run_ex_dir, args, ii + 1, is_show_label_sign)
        line_loss = chart_loss(all_epoch_train_loss, all_epoch_vali_loss, all_epoch_test_loss, epoch_count, run_ex_dir,
                               args, ii + 1)
        # 将预测的预测和未来的真实值一起进行可视化
        if true != []:
            line_pt = chart_predict_and_true(pred_date, future_pred, true, run_ex_dir, args, ii + 1, is_show_label_sign)
            page_pt.add(line_pt)

        # 将图表加入page
        page_loss.add(line_loss)
        page_p.add(line_p)

        # 清除cuda的缓存
        torch.cuda.empty_cache()
        # 删除存储的模型
        shutil.rmtree(run_name_dir_ckp)

    # 可视化page
    page_loss.render(os.path.join(run_name_dir, "训练-验证-损失可视化-test.html"))
    page_loss.save_resize_html(source=os.path.join(run_name_dir, "训练-验证-损失可视化-test.html"),
                               cfg_file=os.path.join('./output/', "chart_config.json"),
                               dest=os.path.join(run_name_dir, "训练-验证-损失可视化.html"))

    page_p.render(os.path.join(run_name_dir, "predict-test.html"))
    page_p.save_resize_html(source=os.path.join(run_name_dir, "predict-test.html"),
                            cfg_file=os.path.join('./output/', "chart_config.json"),
                            dest=os.path.join(run_name_dir, "predict.html"))

    # 存储字典文件
    data_dict["date"] = pred_dates
    # print(data_dict['true'])
    # print(data_dict['实验1'])
    # print(data_dict['date'])
    # print(len(data_dict['true']))
    # print(len(data_dict['实验1']))
    # print(len(data_dict['date']))
    # sys.exit()
    df = pd.DataFrame(data_dict)
    df["pred_mean_NH3"] = df[df_columns].mean(axis=1)
    df_columns.insert(0, 'pred_mean_NH3')

    if true != []:
        df_columns.insert(0, 'true')
        page_pt.render(os.path.join(run_name_dir, "predict-true-test.html"))
        page_pt.save_resize_html(source=os.path.join(run_name_dir, "predict-true-test.html"),
                                 cfg_file=os.path.join('./output/', "chart_config.json"),
                                 dest=os.path.join(run_name_dir, "predict-true.html"))
    df_columns.insert(0, 'date')
    df = df[df_columns]
    df["pred_mean_NH3"] = round(df["pred_mean_NH3"], 3)
    df.to_csv(os.path.join(run_name_dir, "{}次实验_{}未来预测结果.csv".format(args.itr, them)), index=False, encoding='utf-8',
              sep=',')
    # print(df)

你可能感兴趣的:(论文源码,python,深度学习,开发语言)