微软新出了一个 Qlib 量化投资平台,导师让我们研究一下。我看了一遍使用手册,发现很多 API 都还不会用。
print('----------0install--------------')
# Installation check: the package imports and reports its version.
import qlib

print(qlib.__version__)

print('----------1download.init--------------')
# Initialise qlib with the data directory and region to use.
from qlib.config import REG_CN, REG_US

qlib.init(
    provider_uri="~/.qlib/qlib_data/cn_data",
    region=REG_CN,
)
# Alternative setup noted by the author (not executed here):
#   provider_uri="~/.qlib/qlib_data/us_data" together with region=REG_US.
# provider_uri values used in these notes:
#   "~/.qlib/qlib_data/cn_data", "~/.qlib/qlib_data/us_data"
# region values used in these notes: REG_CN, REG_US
print('----------2download.ex--------------')
# Data retrieval examples using the `D` API.
from qlib.data import D

# First ten entries of the daily calendar requested for 2020.
a = D.calendar(
    start_time='2020-01-01',
    end_time='2020-12-31',
    freq='day',
)[:10]

# Instrument configuration for the 'all' market.
b = D.instruments(market='all')

# First six entries of the 'csi300' instrument list over 2010-2020.
instruments = D.instruments(market='csi300')
c = D.list_instruments(
    instruments=instruments,
    start_time='2010-01-01',
    end_time='2020-12-31',
    as_list=True,
)[:6]

# One result per line.
print(a)
print(b)
print(c)
from qlib.data.filter import NameDFilter

# Filter the 'csi300' instruments by a regular expression on the instrument name.
nameDFilter = NameDFilter(name_rule_re='SH[0-9]{4}55')
instruments = D.instruments(market='csi300', filter_pipe=[nameDFilter])
d = D.list_instruments(
    instruments=instruments,
    start_time='2015-01-01',
    end_time='2016-02-15',
    as_list=True,
)
print(d)

from qlib.data.filter import ExpressionDFilter

# Filter the 'csi300' instruments by an expression rule on a data field.
expressionDFilter = ExpressionDFilter(rule_expression='$close>2000')
instruments = D.instruments(market='csi300', filter_pipe=[expressionDFilter])
e = D.list_instruments(
    instruments=instruments,
    start_time='2015-01-01',
    end_time='2016-02-15',
    as_list=True,
)
print(e)
# Query raw fields and expression fields for a single instrument; show the head only.
instruments = ['SH600000']
fields = [
    '$close',
    '$volume',
    'Ref($close, 1)',
    'Mean($close, 3)',
    '$high-$low',
]
f = D.features(
    instruments,
    fields,
    start_time='2010-01-01',
    end_time='2017-12-31',
    freq='day',
).head()
print(f, type(f))

# Same query for two instruments, keeping the full result.
instruments = ['SH600000', 'SZ000001']
fields = [
    '$close',
    '$volume',
    'Ref($close, 1)',
    'Mean($close, 3)',
    '$high-$low',
]
g = D.features(
    instruments,
    fields,
    start_time='2010-01-01',
    end_time='2017-12-31',
    freq='day',
)
print(g, type(g), g.index)
# Output recorded by the author:
#   type(g) = (value missing from the original note)
#   g.index = MultiIndex([...], names=['instrument', 'datetime'], length=3794)
print('----------3customerclass.ex--------------')
# Example of a user-defined model class.
from qlib.model.base import Model
import pandas


class CJ_model(Model):
    """Skeleton of a custom model derived from qlib's ``Model`` base class.

    Only the constructor contains logic; ``fit``, ``predict`` and ``finetune``
    are placeholders that do nothing and return ``None``.
    """

    def __init__(self, loss='mse', **kwargs):
        """Validate the loss name and record the model parameters.

        Args:
            loss: Objective name; must be ``'mse'`` or ``'binary'``.
            **kwargs: Extra parameters stored in ``self._params`` next to
                ``objective``.

        Raises:
            NotImplementedError: If ``loss`` is not one of the supported names.
        """
        super().__init__()
        if loss not in {'mse', 'binary'}:
            raise NotImplementedError
        # NOTE(review): mean_squared_error / roc_auc_score are not imported
        # anywhere in the visible file (presumably sklearn.metrics - confirm).
        # They must be imported before this class can be instantiated.
        self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score
        # The original updated self._params without ever creating it; create
        # the dict first unless a base class already provides one.
        if not hasattr(self, '_params'):
            self._params = {}
        self._params.update(objective=loss, **kwargs)
        self._model = None

    def fit(self, num_boost_round=1000, **kwargs):
        """Placeholder for training; not implemented yet."""
        pass

    def predict(self, **kwargs) -> pandas.Series:
        """Placeholder for prediction; not implemented yet (returns ``None``)."""
        pass

    def finetune(self, num_boost_round=10, verbose_eval=20):
        """Placeholder for fine-tuning; not implemented yet."""
        pass
# Load the Alpha158 data through its data handler.
import qlib
from qlib.contrib.data.handler import Alpha158

# Keyword arguments passed to the Alpha158 handler.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": "csi300",
}

# The guarded body had lost its indentation, which is a syntax error.
# Everything that builds or uses the handler `h` belongs under the guard.
if __name__ == "__main__":
    qlib.init()
    h = Alpha158(**data_handler_config)
    print(type(h))

    # Column names exposed by the handler.
    print(h.get_cols())

    # Label columns.
    print(h.fetch(col_set="label"))

    # Feature columns.
    print(h.fetch(col_set="feature"))
print('----------4workflow.ex--------------')
# Run the workflow from inside a .py file to obtain prediction scores.
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord

market = "csi300"
benchmark = "SH000300"

# Keyword arguments for the Alpha158 data handler.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
}

# Declarative description of the model and dataset; each entry names a class,
# the module it lives in, and the kwargs used to construct it.
task = {
    "model": {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
        "kwargs": {
            "loss": "mse",
            "colsample_bytree": 0.8879,
            "learning_rate": 0.0421,
            "subsample": 0.8789,
            "lambda_l1": 205.6999,
            "lambda_l2": 580.9768,
            "max_depth": 8,
            "num_leaves": 210,
            "num_threads": 20,
        },
    },
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            # Date ranges for each split.
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    },
}
# Model and dataset initialisation from the `task` configuration.
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

# Start the experiment. The body of the `with` statement had lost its
# indentation (a syntax error); training and prediction both run inside it.
with R.start(experiment_name="workflow"):
    # Train.
    R.log_params(**flatten_dict(task))
    model.fit(dataset)

    # Prediction.
    recorder = R.get_recorder()
    sr = SignalRecord(model, dataset, recorder)
    sr.generate()
print('----------5dataAPI--------------')
from qlib.data.data import CalendarProvider


class subcla(CalendarProvider):
    """Minimal ``CalendarProvider`` subclass used to experiment with the provider API."""

    def calendar(self, start_time, end_time, freq, future):
        """Stub override: ignores its arguments and returns ``None``."""
        pass


d = subcla()
# `calendar` is the stub above, so the first print shows `1 None`.
print(1, d.calendar(start_time='2010-01-01', end_time='2020-12-31', freq='day', future=False))
# `locate_index` is not defined in this file; it is looked up on the base class.
print(2, d.locate_index(start_time='2010-01-01', end_time='2020-12-31', freq='day', future=False))
# The calls below are made directly on the provider classes; no instance is created.
from qlib.data.data import InstrumentProvider

print(3, InstrumentProvider.instruments(market='all', filter_pipe=None))

from qlib.data.data import FeatureProvider
from qlib.data.data import ExpressionProvider
from qlib.data.data import DatasetProvider

print(
    4,
    DatasetProvider.get_instruments_d(
        instruments=['SH600000', 'SH600001'],
        freq='day',
    ),
)

import pandas as pd

print(5, DatasetProvider.get_column_names(pd.read_csv('1.csv')))
print(
    6,
    DatasetProvider.parse_fields(
        fields=['$close', '$volume', 'Ref($close, 1)', 'Mean($close, 3)', '$high-$low'],
    ),
)
print(
    7,
    DatasetProvider.dataset_processor(
        instruments_d='SH600000',
        column_names=['$close'],
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
    ),
)
# Same experiments against the Local* provider classes, one instance each.
from qlib.data.data import LocalCalendarProvider

e = LocalCalendarProvider()
print(
    8,
    e.calendar(
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
        future=False,
    ),
)

from qlib.data.data import LocalInstrumentProvider

f = LocalInstrumentProvider()
print(
    9,
    f.list_instruments(
        instruments=D.instruments(market='csi300'),
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
        as_list=False,
    ),
)

from qlib.data.data import LocalFeatureProvider

g = LocalFeatureProvider()
print(
    10,
    g.feature(
        instrument='SH600000',
        field=['$close'],
        start_index='2010-01-01',
        end_index='2020-12-31',
        freq='day',
    ),
)

from qlib.data.data import LocalExpressionProvider

h = LocalExpressionProvider()
print(
    11,
    h.expression(
        instrument='SH600000',
        field='$close',
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
    ),
)

# NOTE: the author marked this point with a long row of '!' characters;
# the reason for the marker is not recorded in the notes.
from qlib.data.data import LocalDatasetProvider

i = LocalDatasetProvider()
print(
    12,
    i.dataset(
        instruments=['SH600000'],
        fields=['$close'],
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
    ),
)
# This one is called on the class itself rather than on the instance `i`.
print(
    13,
    LocalDatasetProvider.multi_cache_walker(
        instruments=['SH600000'],
        fields=['$close'],
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
    ),
)
print('----------6filterAPI--------------')
from qlib.data.filter import BaseDFilter
from qlib.data.filter import SeriesDFilter
from qlib.data.filter import NameDFilter
from qlib.data.data import InstrumentProvider

# Build a name-filtered 'csi300' instrument config, then list it with the
# LocalInstrumentProvider instance `f` created earlier in the script.
cj_instruments = InstrumentProvider.instruments(
    market='csi300',
    filter_pipe=[NameDFilter(name_rule_re='SH[0-9]{4}55')],
)
print(
    14,
    f.list_instruments(
        instruments=cj_instruments,
        start_time='2010-01-01',
        end_time='2020-12-31',
        freq='day',
        as_list=False,
    ),
)

from qlib.data.filter import ExpressionDFilter

# Only the filter objects themselves are constructed and printed here.
print(
    15,
    ExpressionDFilter(
        rule_expression='$close /$open > 5',
        fstart_time=None,
        fend_time=None,
        keep=False,
    ),
)
print(
    16,
    ExpressionDFilter(
        rule_expression='$rank($close) < 10',
        fstart_time=None,
        fend_time=None,
        keep=False,
    ),
)
print(
    17,
    ExpressionDFilter(
        rule_expression='$Ref($close, 3) > 100',
        fstart_time=None,
        fend_time=None,
        keep=False,
    ),
)
print('----------7baseAPI--------------')
print('----------8operatorAPI--------------')
from qlib.data.ops import Abs, Sign, Log, Power, Lt
from qlib.data.data import LocalExpressionProvider

h = LocalExpressionProvider()

# Fetch the same '$close' expression twice to have two operands for the operators.
x1 = h.expression(
    instrument='SH600000',
    field='$close',
    start_time='2010-01-01',
    end_time='2020-12-31',
    freq='day',
)
x2 = h.expression(
    instrument='SH600000',
    field='$close',
    start_time='2010-01-01',
    end_time='2020-12-31',
    freq='day',
)

# Construct a few operator objects from the fetched data and print them.
print(Abs(x1), Sign(x1), Log(x1), Lt(x1, x2))
print('---------9CacheAPI-----------------')
print('---------------10Datasetapi-----------')
from qlib.data.dataset import processor
import pandas as pd

# Read the local CSV file, using its "datetime" column as the index.
x = pd.read_csv('1.csv', index_col="datetime")
y = processor.get_group_columns(x, 'return')
print(y)  # Author's remark: this wrapper seems rather odd to me.

from qlib.data.dataset.processor import Processor

xx = Processor()
print(xx.fit(x))

from qlib.contrib import evaluate

# Risk analysis on the 'return' column of the CSV data.
print(evaluate.risk_analysis(x['return'], N=5))
现在开始第二轮学习:对照官方给出的例程,逐行研究!
对了,补充一下 Jupyter Notebook 的三个快捷键:dd、Shift+Enter、Ctrl+Shift+-,分别用于删除单元格、运行当前单元格、在光标处拆分单元格。