Getting Started with Quantitative Investing (1): Using Qlib

I have been trading stocks for about six years, but always as discretionary trades based on personal experience. I recently came across Qlib, an open-source quantitative investing library, and worked through the official tutorial with my own data.

Note: this article only follows the official tutorial with a self-built dataset. How to analyse the output, let alone invest based on it, is still beyond me; I am new to this field and many of the concepts are unfamiliar. (Concretely: I think I understand the strategy, but after training I cannot find which stocks Qlib actually picked in each period; all I get is a pile of curves I cannot make sense of.) If you are interested in this topic, or can interpret the results, please leave a comment or contact me.

Contents

  • Preface
  • 1. Fetching data with akshare and baostock
  • 2. Generating the data
  • 3. Workflow


Preface

Code reference: [Official Qlib repository](https://github.com/microsoft/qlib)

Key external packages: akshare, baostock (see their official documentation for installation)


1. Fetching data with akshare and baostock


'''

Qlib field      baostock field
change          pctChg       daily percentage change
close           close        adjusted close
factor                       adjustment factor
high            high         adjusted high
low             low          adjusted low
open            open         adjusted open
volume          volume       trading volume


No data cleaning is done here: low-priced and illiquid stocks are not removed,
the training and test periods are short, and no hyper-parameter tuning is performed.
'''
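A quick worked example of the factor column (this is how the script below derives it, from baostock's back-adjusted and unadjusted open prices; the numbers are made up):

# Hypothetical prices for one stock on one day:
hfq_open = 25.00               # back-adjusted open, stored in the csv as `open`
bfq_open = 12.50               # unadjusted open
factor = hfq_open / bfq_open   # adjustment factor written to the csv -> 2.0
raw_open = hfq_open / factor   # dividing any adjusted price by the factor recovers the raw price
print('{:.7f}'.format(factor), raw_open)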

import os
import time

import akshare as ak
import baostock as bs
import numpy as np
import pandas as pd

### Step 1: get the codes of all stocks

# Shenzhen A-share list
all_sz = ak.stock_info_sz_name_code(indicator="A股列表")
# Shanghai main-board A and B shares
all_sh1 = ak.stock_info_sh_name_code(indicator="主板A股")
all_sh2 = ak.stock_info_sh_name_code(indicator="主板B股")
# recently listed ("sub-new") stocks
all_new = ak.stock_zh_a_new()
# ChiNext, the risk-warning board, delisted stocks and ST/*ST stocks
# are not handled here and would need the same treatment


df1 = 'sz' + all_sz.A股代码
df2 = 'sh' + all_sh1.COMPANY_CODE
df3 = 'sh' + all_sh2.COMPANY_CODE
df4 = all_new.symbol

t1 = np.array(df1)
t2 = np.array(df2)
t3 = np.array(df3)
t4 = np.array(df4)

# append each market's index code so its quotes are downloaded too
stock_sz = np.hstack([t1, 'sz399107'])
stock_sh = np.hstack([t2, t3, 'sh000001'])
stock_new = t4


# group all code lists in one dict; the download loop below skips any code
# that appears in stock_new (sub-new stocks)
stock_item = {'深证A指': stock_sz, '上证指数': stock_sh, '次新股': stock_new}

# the output directories must exist before to_csv is called
for key in stock_item:
    os.makedirs("./csv_data/%s" % key, exist_ok=True)

#### Step 2: log in to baostock ####
lg = bs.login()
# show the login response if needed
# print('login respond error_code:' + lg.error_code)
# print('login respond  error_msg:' + lg.error_msg)

#### Step 3: fetch historical K-line data for Shanghai/Shenzhen A shares ####
data_start = '2018-01-01'
data_end = '2021-01-22'
# See the baostock "historical quote fields" documentation for the full field list;
# minute-bar fields differ from daily-bar fields, and minute bars do not cover indices.
# minute bars : date,time,code,open,high,low,close,volume,amount,adjustflag
# week/month  : date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg
# daily bars  : date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg
for key, value in stock_item.items():
    for single_stock in value:
        if single_stock not in stock_new:  # data cleaning: skip delisted, risk-warning and sub-new stocks, etc.
            ### adjustflag defaults to back-adjusted (1: back-adjusted, 2: forward-adjusted, 3: unadjusted)
            rs = bs.query_history_k_data_plus(single_stock,
                                              "code,date,open,high,low,close,volume,amount,turn,pctChg",
                                              start_date=data_start,
                                              end_date=data_end,
                                              frequency="d")

            ### querying a single unadjusted price is enough to compute the adjustment factor
            bfq = bs.query_history_k_data_plus(single_stock,
                                               "open",
                                               start_date=data_start,
                                               end_date=data_end,
                                               frequency="d",
                                               adjustflag='3')

            # print('query_history_k_data_plus respond error_code:' + rs.error_code)
            # print('query_history_k_data_plus respond  error_msg:' + rs.error_msg)

            #### collect the result set ####
            data_list = []
            while (rs.error_code == '0') and rs.next():
                # fetch one record and append it to the list
                imf = rs.get_row_data()
                # strip the dot from the code, e.g. 'sz.000001' -> 'sz000001'
                imf[0] = imf[0][:2] + imf[0][3:]
                bfq_open = float(bfq.get_row_data()[0])
                hfq_open = float(imf[2])  # adjusted open (index 2 in the requested field list)
                factor = '{:.7f}'.format(hfq_open / bfq_open)
                imf.append(factor)
                data_list.append(imf)

            new_columns = rs.fields
            # rename pctChg to change and append the factor column
            new_columns[-1] = 'change'
            new_columns.append('factor')
            # print(rs.fields)
            result = pd.DataFrame(data_list, columns=new_columns)
            #### write the result set to a csv file ####
            if len(result) > 2:
                result.to_csv("./csv_data/%s/%s.csv" % (key, single_stock), index=False)

#### log out ####
bs.logout()
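Before converting to Qlib's binary format it is worth opening one of the generated files to confirm the column layout; a minimal check (the path is only illustrative, pick any csv the script actually produced):

import pandas as pd

sample = pd.read_csv("./csv_data/上证指数/sh600000.csv")
print(sample.columns.tolist())
# expected: ['code', 'date', 'open', 'high', 'low', 'close', 'volume',
#            'amount', 'turn', 'change', 'factor']
print(sample.head())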

2. Generating the data

Note:
'''
Run the dump_all command below; it takes the following parameters:

  1. symbol_field_name: name of the stock-code column in the csv files, here code;
  2. date_field_name: name of the date column in the csv files, here date;
  3. include_fields: the remaining field names; do not put spaces after the commas,
     otherwise the conversion will fail.

For example:
python dump_bin.py dump_all --csv_path ./csv_data/sh_data --qlib_dir ./qlib_sh_data/sh_data --symbol_field_name code --date_field_name date --include_fields open,high,low,close,volume,amount,turn,change,factor

python dump_bin.py dump_all --csv_path ./csv_data/sz_data --qlib_dir ./qlib_sz_data/sz_data --symbol_field_name code --date_field_name date --include_fields open,high,low,close,volume,amount,turn,change,factor

'''
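After dump_all finishes, the qlib_dir given above should contain three sub-directories, matching the constants defined in the script below: calendars/day.txt with the trading calendar, features/<stock_code>/<field>.day.bin with one binary file per dumped field per instrument, and instruments/all.txt listing each instrument with its first and last trading date (tab-separated). This is the layout that qlib.init later receives as provider_uri in the workflow section.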

import abc
import shutil
import traceback
from pathlib import Path
from typing import Iterable, List, Union
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor

import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger


class DumpDataBase:
    INSTRUMENTS_START_FIELD = "start_datetime"
    INSTRUMENTS_END_FIELD = "end_datetime"
    CALENDARS_DIR_NAME = "calendars"
    FEATURES_DIR_NAME = "features"
    INSTRUMENTS_DIR_NAME = "instruments"
    DUMP_FILE_SUFFIX = ".bin"
    DAILY_FORMAT = "%Y-%m-%d"
    HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
    INSTRUMENTS_SEP = "\t"
    INSTRUMENTS_FILE_NAME = "all.txt"
    SAVE_INST_FIELD = "save_inst"

    UPDATE_MODE = "update"
    ALL_MODE = "all"

    def __init__(
        self,
        csv_path: str,
        qlib_dir: str,
        backup_dir: str = None,
        freq: str = "day",
        max_workers: int = 16,
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        symbol_field_name: str = "symbol",
        exclude_fields: str = "",
        include_fields: str = "",
        limit_nums: int = None,
        inst_prefix: str = "",
    ):
        """

        Parameters
        ----------
        csv_path: str
            stock data path or directory
        qlib_dir: str
            qlib(dump) data director
        backup_dir: str, default None
            if backup_dir is not None, backup qlib_dir to backup_dir
        freq: str, default "day"
            transaction frequency
        max_workers: int, default None
            number of threads
        date_field_name: str, default "date"
            the name of the date field in the csv
        file_suffix: str, default ".csv"
            file suffix
        symbol_field_name: str, default "symbol"
            symbol field name
        include_fields: tuple
            dump fields
        exclude_fields: tuple
            fields not dumped
        limit_nums: int
            Use when debugging, default None
        inst_prefix: str
            add a column to the instruments file and record the saved instrument name,
            the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
        """
        csv_path = Path(csv_path).expanduser()
        if isinstance(exclude_fields, str):
            exclude_fields = exclude_fields.split(",")
        if isinstance(include_fields, str):
            include_fields = include_fields.split(",")
        self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
        self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
        self._inst_prefix = inst_prefix.strip()
        self.file_suffix = file_suffix
        self.symbol_field_name = symbol_field_name
        self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
        if limit_nums is not None:
            self.csv_files = self.csv_files[: int(limit_nums)]
        self.qlib_dir = Path(qlib_dir).expanduser()
        self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
        if backup_dir is not None:
            self._backup_qlib_dir(Path(backup_dir).expanduser())

        self.freq = freq
        self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT

        self.works = max_workers
        self.date_field_name = date_field_name

        self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME)
        self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME)
        self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME)

        self._calendars_list = []

        self._mode = self.ALL_MODE
        self._kwargs = {}

    def _backup_qlib_dir(self, target_dir: Path):
        shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))

    def _format_datetime(self, datetime_d: [str, pd.Timestamp]):
        datetime_d = pd.Timestamp(datetime_d)
        return datetime_d.strftime(self.calendar_format)

    def _get_date(
        self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False
    ) -> Iterable[pd.Timestamp]:
        if not isinstance(file_or_df, pd.DataFrame):
            df = self._get_source_data(file_or_df)
        else:
            df = file_or_df
        if df.empty or self.date_field_name not in df.columns.tolist():
            _calendars = pd.Series()
        else:
            _calendars = df[self.date_field_name]

        if is_begin_end and as_set:
            return (_calendars.min(), _calendars.max()), set(_calendars)
        elif is_begin_end:
            return _calendars.min(), _calendars.max()
        elif as_set:
            return set(_calendars)
        else:
            return _calendars.tolist()

    def _get_source_data(self, file_path: Path) -> pd.DataFrame:
        df = pd.read_csv(str(file_path.resolve()), low_memory=False)
        df[self.date_field_name] = df[self.date_field_name].astype(str).astype(np.datetime64)
        # df.drop_duplicates([self.date_field_name], inplace=True)
        return df

    def get_symbol_from_file(self, file_path: Path) -> str:
        return file_path.name[: -len(self.file_suffix)].strip().lower()

    def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
        return (
            self._include_fields
            if self._include_fields
            else set(df_columns) - set(self._exclude_fields)
            if self._exclude_fields
            else df_columns
        )

    @staticmethod
    def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
        return sorted(
            map(
                pd.Timestamp,
                pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),
            )
        )

    def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
        df = pd.read_csv(
            instrument_path,
            sep=self.INSTRUMENTS_SEP,
            names=[
                self.symbol_field_name,
                self.INSTRUMENTS_START_FIELD,
                self.INSTRUMENTS_END_FIELD,
                self.SAVE_INST_FIELD,
            ],
        )

        return df

    def save_calendars(self, calendars_data: list):
        self._calendars_dir.mkdir(parents=True, exist_ok=True)
        calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
        result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
        np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")

    def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
        self._instruments_dir.mkdir(parents=True, exist_ok=True)
        instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
        if isinstance(instruments_data, pd.DataFrame):
            _df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
            if self._inst_prefix:
                _df_fields.append(self.SAVE_INST_FIELD)
                instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
                    lambda x: f"{self._inst_prefix}{x}"
                )
            instruments_data = instruments_data.loc[:, _df_fields]
            instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
        else:
            np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")

    def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
        # calendars
        calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
        calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
        cal_df = calendars_df[
            (calendars_df[self.date_field_name] >= df[self.date_field_name].min())
            & (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
        ]
        # align index
        cal_df.set_index(self.date_field_name, inplace=True)
        df.set_index(self.date_field_name, inplace=True)
        r_df = df.reindex(cal_df.index)
        return r_df

    @staticmethod
    def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int:
        return calendar_list.index(df.index.min())

    def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):
        if df.empty:
            logger.warning(f"{features_dir.name} data is None or empty")
            return
        # align index
        _df = self.data_merge_calendar(df, self._calendars_list)
        date_index = self.get_datetime_index(_df, calendar_list)
        for field in self.get_dump_fields(_df.columns):
            bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
            if field not in _df.columns:
                continue
            if self._mode == self.UPDATE_MODE:
                # update
                with bin_path.open("ab") as fp:
                    np.array(_df[field]).astype("<f").tofile(fp)
            elif self._mode == self.ALL_MODE:
                np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
            else:
                raise ValueError(f"{self._mode} cannot support!")

    def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
        if isinstance(file_or_data, pd.DataFrame):
            if file_or_data.empty:
                return
            code = file_or_data.iloc[0][self.symbol_field_name].lower()
            df = file_or_data
        elif isinstance(file_or_data, Path):
            code = self.get_symbol_from_file(file_or_data)
            df = self._get_source_data(file_or_data)
        else:
            raise ValueError(f"not support {type(file_or_data)}")
        if df is None or df.empty:
            logger.warning(f"{code} data is None or empty")
            return
        # features save dir
        code = self._inst_prefix + code if self._inst_prefix else code
        features_dir = self._features_dir.joinpath(code)
        features_dir.mkdir(parents=True, exist_ok=True)
        self._data_to_bin(df, calendar_list, features_dir)

    @abc.abstractmethod
    def dump(self):
        raise NotImplementedError("dump not implemented!")

    def __call__(self, *args, **kwargs):
        self.dump()


class DumpDataAll(DumpDataBase):
    def _get_all_date(self):
        logger.info("start get all date......")
        all_datetime = set()
        date_range_list = []
        _fun = partial(self._get_date, as_set=True, is_begin_end=True)
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.works) as executor:
                for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
                    self.csv_files, executor.map(_fun, self.csv_files)
                ):
                    all_datetime = all_datetime | _set_calendars
                    if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
                        _begin_time = self._format_datetime(_begin_time)
                        _end_time = self._format_datetime(_end_time)
                        symbol = self.get_symbol_from_file(file_path)
                        _inst_fields = [symbol.upper(), _begin_time, _end_time]
                        if self._inst_prefix:
                            _inst_fields.append(self._inst_prefix + symbol.upper())
                        date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
                    p_bar.update()
        self._kwargs["all_datetime_set"] = all_datetime
        self._kwargs["date_range_list"] = date_range_list
        logger.info("end of get all date.\n")

    def _dump_calendars(self):
        logger.info("start dump calendars......")
        self._calendars_list = sorted(map(pd.Timestamp, self._kwargs["all_datetime_set"]))
        self.save_calendars(self._calendars_list)
        logger.info("end of calendars dump.\n")

    def _dump_instruments(self):
        logger.info("start dump instruments......")
        self.save_instruments(self._kwargs["date_range_list"])
        logger.info("end of instruments dump.\n")

    def _dump_features(self):
        logger.info("start dump features......")
        _dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.works) as executor:
                for _ in executor.map(_dump_func, self.csv_files):
                    p_bar.update()

        logger.info("end of features dump.\n")

    def dump(self):
        self._get_all_date()
        self._dump_calendars()
        self._dump_instruments()
        self._dump_features()


class DumpDataFix(DumpDataAll):
    def _dump_instruments(self):
        logger.info("start dump instruments......")
        _fun = partial(self._get_date, is_begin_end=True)
        new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
        with tqdm(total=len(new_stock_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.works) as execute:
                for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
                    if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
                        symbol = self.get_symbol_from_file(file_path).upper()
                        _dt_map = self._old_instruments.setdefault(symbol, dict())
                        _dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
                        _dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
                    p_bar.update()
        _inst_df = pd.DataFrame.from_dict(self._old_instruments, orient="index")
        _inst_df.index.names = [self.symbol_field_name]
        self.save_instruments(_inst_df.reset_index())
        logger.info("end of instruments dump.\n")

    def dump(self):
        self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
        # noinspection PyAttributeOutsideInit
        self._old_instruments = (
            self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
            .set_index([self.symbol_field_name])
            .to_dict(orient="index")
        )  # type: dict
        self._dump_instruments()
        self._dump_features()


class DumpDataUpdate(DumpDataBase):
    def __init__(
        self,
        csv_path: str,
        qlib_dir: str,
        backup_dir: str = None,
        freq: str = "day",
        max_workers: int = 16,
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        symbol_field_name: str = "symbol",
        exclude_fields: str = "",
        include_fields: str = "",
        limit_nums: int = None,
    ):
        """

        Parameters
        ----------
        csv_path: str
            stock data path or directory
        qlib_dir: str
            qlib(dump) data director
        backup_dir: str, default None
            if backup_dir is not None, backup qlib_dir to backup_dir
        freq: str, default "day"
            transaction frequency
        max_workers: int, default None
            number of threads
        date_field_name: str, default "date"
            the name of the date field in the csv
        file_suffix: str, default ".csv"
            file suffix
        symbol_field_name: str, default "symbol"
            symbol field name
        include_fields: tuple
            dump fields
        exclude_fields: tuple
            fields not dumped
        limit_nums: int
            Use when debugging, default None
        """
        super().__init__(
            csv_path,
            qlib_dir,
            backup_dir,
            freq,
            max_workers,
            date_field_name,
            file_suffix,
            symbol_field_name,
            exclude_fields,
            include_fields,
        )
        self._mode = self.UPDATE_MODE
        self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
        self._update_instruments = self._read_instruments(
            self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
        ).to_dict(
            orient="index"
        )  # type: dict

        # load all csv files
        self._all_data = self._load_all_source_data()  # type: pd.DataFrame
        self._update_calendars = sorted(
            filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
        )
        self._new_calendar_list = self._old_calendar_list + self._update_calendars

    def _load_all_source_data(self):
        # NOTE: Need more memory
        logger.info("start load all source data....")
        all_df = []

        def _read_csv(file_path: Path):
            if self._include_fields:
                _df = pd.read_csv(file_path, usecols=self._include_fields)
            else:
                _df = pd.read_csv(file_path)
            if self.symbol_field_name not in _df.columns:
                _df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
            return _df

        with tqdm(total=len(self.csv_files)) as p_bar:
            with ThreadPoolExecutor(max_workers=self.works) as executor:
                for df in executor.map(_read_csv, self.csv_files):
                    if not df.empty:
                        all_df.append(df)
                    p_bar.update()

        logger.info("end of load all data.\n")
        return pd.concat(all_df, sort=False)

    def _dump_calendars(self):
        pass

    def _dump_instruments(self):
        pass

    def _dump_features(self):
        logger.info("start dump features......")
        error_code = {}
        with ProcessPoolExecutor(max_workers=self.works) as executor:
            futures = {}
            for _code, _df in self._all_data.groupby(self.symbol_field_name):
                _code = str(_code).upper()
                _start, _end = self._get_date(_df, is_begin_end=True)
                if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
                    continue
                if _code in self._update_instruments:
                    self._update_instruments[_code]["end_time"] = _end
                    futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
                else:
                    # new stock
                    _dt_range = self._update_instruments.setdefault(_code, dict())
                    _dt_range["start_time"] = _start
                    _dt_range["end_time"] = _end
                    futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code

            for _future in tqdm(as_completed(futures)):
                try:
                    _future.result()
                except Exception:
                    error_code[futures[_future]] = traceback.format_exc()
            logger.info(f"dump bin errors: {error_code}")

        logger.info("end of features dump.\n")

    def dump(self):
        self.save_calendars(self._new_calendar_list)
        self._dump_features()
        self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))


if __name__ == "__main__":
    fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate})
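The fire entry point mirrors the CLI commands shown in the note above; the same class can also be driven directly from Python, for example (the paths are illustrative):

# Equivalent to the dump_all CLI invocation above; adjust the paths to your layout.
dumper = DumpDataAll(
    csv_path="./csv_data/sh_data",
    qlib_dir="./qlib_sh_data/sh_data",
    symbol_field_name="code",
    date_field_name="date",
    include_fields="open,high,low,close,volume,amount,turn,change,factor",
)
dumper()  # __call__ simply runs dump()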

3. Workflow

Note: this part follows examples/workflow_by_code.ipynb almost verbatim; it is best run in a notebook so that the results can be visualized interactively.

#!/usr/bin/env python
# coding: utf-8


# #  Copyright (c) Microsoft Corporation.
# #  Licensed under the MIT License.
# Building the LightGBM stock-selection strategy
#
# The following is the core stock-selection part, again based on the Qlib example
# workflow_by_code.py (the write-up this follows adapted the example's A-share
# strategy to Hong Kong stocks; here the A-share data prepared above is used).
# First, import the Qlib modules that will be called later.

# # Get/prepare the data (generated with the script from Step 1 above)

# # Relevant Qlib modules

# In[1]:


import time
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
    backtest as normal_backtest,
    risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.utils import flatten_dict


# # Data Layer: define the dataset for training
# https://qlib.readthedocs.io/en/latest/component/data.html#

# In[2]:


# use the data dumped in step 2 above (the official example instead downloads
# data with: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data)
provider_uri = "./dataset/qlib_sh_data/sh_data"  # target_dir
# if not exists_qlib_data(provider_uri):
#     print(f"Qlib data is not found in {provider_uri}")
#     sys.path.append(str(scripts_dir))
#     from get_data import GetData
#     GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)  # initialize qlib


# # Initialize the environment and define the stock pool

# In[3]:


from qlib.data import D
import time
from qlib.data.filter import ExpressionDFilter

# stock pool: keep stocks whose close is above 1; the index itself was removed
# by hand from instruments/all.txt (the sh000001 line)
expressionDFilter = ExpressionDFilter(rule_expression='$close>1')
instruments = D.instruments(market='all', filter_pipe=[expressionDFilter])
market = instruments

# benchmark index
benchmark = "SH000001"

# # Workflow: train the model
# Users could build their own Quant research workflow with these components like Example

# In[4]:


###################################
# train model
###################################

'''
Define data_handler_config, the factor-generation parameters.
It acts like a configuration file (a dict) that specifies the full data range
(start_time and end_time), the fitting range (fit_start_time and fit_end_time),
the stock pool (instruments), and so on.
The fitting range should be a subset of the full data range.
'''
data_handler_config = {
        "start_time": "2018-01-01",
        "end_time": "2021-01-22",
        "fit_start_time": "2018-01-01",
        "fit_end_time": "2020-06-30",
        "instruments": market,
}

'''
Define task, the model-training parameters.
It has two parts: model and dataset.
'''

# The first part, model, holds the model parameters. It must contain the
# class (model class name) and module_path (where the model lives) keys;
# kwargs is optional and sets the model's hyper-parameters.

task = {
    "model": {
        "class": "LGBModel",  # model class name, here LightGBM
        "module_path": "qlib.contrib.model.gbdt",  # where the model class lives
        "kwargs": {  # LightGBM hyper-parameters
            "loss": "mse",  # loss function, here mean squared error
            "colsample_bytree": 0.8879,  # column (feature) sampling ratio
            "learning_rate": 0.0421,  # learning rate
            "subsample": 0.8789,  # row sampling ratio
            "lambda_l1": 205.6999,  # L1 regularization coefficient
            "lambda_l2": 580.9768,  # L2 regularization coefficient
            "max_depth": 8,  # maximum tree depth
            "num_leaves": 210,  # maximum number of leaves
            "num_threads": 20,  # maximum number of parallel threads
        },
    },

    # The second part, dataset, holds the dataset parameters. It must contain the
    # class (dataset class name) and module_path keys; kwargs is optional and sets
    # the dataset's own parameters.
    "dataset": {
        "class": "DatasetH",  # dataset class name
        "module_path": "qlib.data.dataset",  # where the dataset class lives
        "kwargs": {  # DatasetH parameters
            "handler": {  # data handler (factor library) parameters
                "class": "Alpha158",  # factor library, a subclass of DataHandlerLP
                "module_path": "qlib.contrib.data.handler",  # factor library path
                "kwargs": data_handler_config,  # Alpha158 parameters
            },
            "segments": {  # time-split parameters
                "train": ("2018-01-01", "2019-12-31"),  # training period
                "valid": ("2020-01-01", "2020-06-30"),  # validation period
                "test": ("2020-07-01", "2021-01-22"),  # test period
            },
        },
    },
}

# model initialization: instantiate the model object
model = init_instance_by_config(task["model"])
# instantiate the factor dataset: all features and labels computed from the raw quote data
dataset = init_instance_by_config(task["dataset"])
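
# Optional check (not part of the original example): look at what Alpha158
# actually produces before training. DatasetH.prepare is the same call that is
# used later to fetch the test labels.
train_features = dataset.prepare("train", col_set="feature")
train_labels = dataset.prepare("train", col_set="label")
print(train_features.shape, train_labels.shape)  # rows are (date, stock) pairs, ~158 feature columns
print(train_features.iloc[:3, :5])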



# start exp to train model
'''
Run the actual training.
In my run, loading and preprocessing the data took about 54 seconds and training
the model about 118 seconds. LightGBM stops early when the validation loss has not
improved for 50 consecutive rounds; here training stopped after 71 iterations.
'''

t_start = time.time()

# R is qlib's workflow manager; R.start opens a training "experiment"
with R.start(experiment_name="train_model"):
    R.log_params(**flatten_dict(task))

    # fit the model
    model.fit(dataset)

    # save the trained model with the experiment record
    R.save_objects(trained_model=model)

    # keep the id of this experiment record (the model-training run)
    rid = R.get_recorder().id
t_end = time.time()
print('training time: %.3fs' % (t_end - t_start))


# # Backtesting the stock-selection strategy
# 
# Forecast/prediction, portfolio, backtest (intraday trading) & analysis

# In[9]:


###################################
# prediction, backtest & analysis
###################################

'''
Next, set the backtest parameters port_analysis_config, a dict with two parts:
strategy and backtest.

The first part, strategy, holds the strategy parameters. The TopkDropout strategy
used here holds topk stocks with equal weight; each day it sells the n_drop held
stocks with the lowest latest predicted return and buys the n_drop unheld stocks
with the highest predicted return. (The official example uses topk=50 and n_drop=5;
topk=5 and n_drop=2 are used below.)

The second part, backtest, holds the backtest parameters: price-limit handling,
starting capital, benchmark, execution price, transaction costs, and so on.
'''
port_analysis_config = {
    "strategy": {  # strategy parameters
        "class": "TopkDropoutStrategy",  # strategy class name
        "module_path": "qlib.contrib.strategy.strategy",  # where the strategy lives
        "kwargs": {  # TopkDropout parameters
            "topk": 5,  # number of stocks held each day
            "n_drop": 2,  # number of stocks swapped each day
        },
    },
    "backtest": {  # backtest parameters
        "verbose": False,  # whether to print backtest information as it runs
        "limit_threshold": 0.095,  # daily price-limit threshold used to mark untradable stocks
        "account": 100000,  # starting capital
        "benchmark": benchmark,  # performance benchmark, here the SSE Composite Index (SH000001)
        "deal_price": "close",  # execution price, here the close
        "open_cost": 0.0005,  # cost rate for opening a position
        "close_cost": 0.0015,  # cost rate for closing a position
        "min_cost": 1000,  # minimum cost per trade
    },
}
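
# With topk=5 and n_drop=2 the strategy replaces 2 of its 5 holdings every day,
# i.e. roughly 2/5 = 40% of the portfolio turns over daily, so the open_cost and
# close_cost rates above weigh heavily on net performance.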


# backtest and analysis
'''
The backtest itself is run through the qlib.workflow module, executing in order:
qlib.workflow.record_temp.SignalRecord to generate the rebalancing signals;
qlib.workflow.record_temp.PortAnaRecord to run the backtest and performance analysis.
'''

# R.start opens the backtest "experiment"
with R.start(experiment_name="backtest_analysis"):
    # fetch the earlier model-training experiment record
    recorder = R.get_recorder(rid, experiment_name="train_model")

    # load the trained model from it
    model = recorder.load_object("trained_model")

    # get a recorder for the backtest experiment itself
    recorder = R.get_recorder()
    ba_rid = recorder.id

    # attach model, dataset and recorder to the signal record
    sr = SignalRecord(model, dataset, recorder)
    # generate the predictions and store them through the recorder
    sr.generate()

    # run the backtest on the test set: the portfolio-analysis record picks up
    # the stored predictions and records the final backtest results
    par = PortAnaRecord(recorder, port_analysis_config)
    par.generate()  # generate the backtest and performance-analysis results
    
'''
While the backtest runs it also prints part of the prediction output, for example
the predicted next-period return of each stock on the first trading day of the test
set (2020-07-01), e.g. -0.016728 for Shanghai Pudong Development Bank (SH600000),
as well as the performance metrics with and without transaction cost:
daily mean return (mean), daily volatility (std), annualized_return,
information_ratio and max_drawdown.
'''


# # analyze graphs
# Display of the backtest and performance-analysis results

# In[11]:


'''
After the backtest, the qlib.contrib.report module is used to display the backtest
and performance-analysis results. First fetch the backtest experiment record with
qlib.workflow.get_recorder; all results are stored as pkl files and are read back
with recorder.load_object:
the predictions pred.pkl,
the backtest report report_normal.pkl,
the daily positions positions_normal.pkl,
and the portfolio analysis port_analysis.pkl.
'''
from qlib.contrib.report import analysis_model, analysis_position
from qlib.data import D
recorder = R.get_recorder(ba_rid, experiment_name="backtest_analysis")
pred_df = recorder.load_object("pred.pkl")
pred_df_dates = pred_df.index.get_level_values(level='datetime')
report_normal_df = recorder.load_object("portfolio_analysis/report_normal.pkl")
positions = recorder.load_object("portfolio_analysis/positions_normal.pkl")
analysis_df = recorder.load_object("portfolio_analysis/port_analysis.pkl")
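
# Worth printing directly (not in the original notebook): analysis_df is the table
# behind the risk-analysis graph below, with mean, std, annualized_return,
# information_ratio and max_drawdown, with and without cost.
print(analysis_df)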

import pandas as pd
import pickle

# note the 'rb' mode: pkl files are binary, so open with 'rb' rather than 'r'
with open("./mlruns/1/dd09d727960a419ca87fa8882664452b/artifacts/pred.pkl", 'rb') as f:
    data = pickle.load(f)
pd.set_option('display.width', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)
print(data)
# dump the predictions to a text file for inspection
with open('回测值.csv', 'w') as ft:
    ft.write(str(data))
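
# To answer the question from the introduction (which stocks does the model pick
# in each period?), the loaded pred_df can be ranked per day. This is a sketch,
# assuming pred_df has a (datetime, instrument) MultiIndex and a single score
# column, which is how this qlib version stores pred.pkl.
score_col = pred_df.columns[0]
daily_topk = (
    pred_df.reset_index()
           .sort_values(['datetime', score_col], ascending=[True, False])
           .groupby('datetime')
           .head(5)      # top 5 matches the topk parameter of the strategy above
)
print(daily_topk.head(20))

# The stocks the strategy actually held each day are in the positions object
# loaded above (positions_normal.pkl); its exact structure depends on the qlib
# version, so the simplest way to see it is to print one entry:
first_day = sorted(positions.keys())[0]
print(first_day, positions[first_day])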


# ## analysis position

# ### report

# In[12]:


analysis_position.report_graph(report_normal_df)


# ### risk analysis

# In[13]:


analysis_position.risk_analysis_graph(analysis_df, report_normal_df)


# ## analysis model

# In[14]:


label_df = dataset.prepare("test", col_set="label")
label_df.columns = ['label']


# ### score IC
# 
# Run the following to plot the IC and Rank IC of the model's return predictions.
# 
# 

# In[15]:


pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)
analysis_position.score_ic_graph(pred_label)


# ### model performance

# In[16]:


analysis_model.model_performance_graph(pred_label)






