研究了一段时间的量化投资,现在已经“从入门到入土”了。我总结了一个量化投资编程的框架,分享给大家。这个框架还不够精细,也不保证完全正确,希望参考它的同学自己再整理一下。另外,跟着视频课程写的一个交易程序也一并放在下面。
具体的视频课程为:金融量化分析入门 - 网易云课堂
https://study.163.com/course/introduction/1004577035.htm?share=1&shareId=1142835134#/courseDetail?tab=1
下面是框架图:
下面是程序:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tushare as ts
import datetime
import dateutil
import time
# Tushare Pro client; the API token is hard-coded here.
# NOTE(review): a credential committed in source should be rotated and loaded
# from configuration or the environment instead.
pro = ts.pro_api('b8611f97d18eca1905fe0a55f0f887938ed9b3afee177c4c703bec94')
# Download the trading calendar, write it to trade_cal.csv, then reload it from
# that file. NOTE(review): the round-trip presumably exists so that cal_date is
# parsed as an integer (it is compared with integer dates later) - confirm.
trade_cal = pro.trade_cal()
trade_cal.to_csv('trade_cal.csv')
trade_cal = pd.read_csv('trade_cal.csv')
# Back-test parameters: starting cash and the date window as YYYYMMDD integers.
CASH = 100000
START_DATE = 20160206
END_DATE = 20180504
class Context:
    """Mutable state of one back-test run.

    Holds the available cash, the requested date window, the current
    holdings, the benchmark security, the trading days to iterate over and
    the date currently being simulated.
    """

    def __init__(self, cash, start_date, end_date):
        # Rows of the module-level trade_cal table that fall inside
        # [start_date, end_date] and are flagged as open.
        in_window = (trade_cal['cal_date'] >= start_date) & (trade_cal['cal_date'] <= end_date)
        is_trading_day = trade_cal['is_open'] == 1
        open_days = trade_cal[in_window & is_trading_day]

        self.cash = cash
        self.start_date = start_date
        self.end_date = end_date
        self.positions = {}  # security -> number of shares held
        self.benchmark = None  # benchmark security
        self.date_range = open_days['cal_date'].values
        self.dt = None  # date currently being simulated
# Shared back-test state, built from the module-level parameters.
context = Context(CASH, START_DATE, END_DATE)
def set_benchmark(security):
    """Record `security` as the benchmark on the shared back-test context."""
    context.benchmark = security
# Empty class: an attribute bag on which a strategy stores its own variables.
class G:
    """Namespace object with no predefined attributes."""


g = G()
# Fetch historical quotes
def attribute_history(security, count, fields=('open', 'close', 'high', 'low', 'volume')):
    """Return quotes for a window of trading days ending before context.dt.

    The window ends on the calendar day before context.dt. Its start is the
    open day found `count` rows from the end of the open-day rows of
    trade_cal up to that end date (this assumes trade_cal is in ascending
    date order - TODO confirm).

    NOTE(review): `fields` is accepted but not forwarded; the request is
    always made for open/close/high/low.
    """
    previous_day = context.dt - datetime.timedelta(days=1)
    window_end = int(previous_day.strftime("%Y%m%d"))
    open_days = trade_cal[(trade_cal['is_open'] == 1) & (trade_cal['cal_date'] <= window_end)]
    window_start = open_days.iloc[-count]['cal_date']
    return attribute_daterange_history(
        security, window_start, window_end, fields=('open', 'close', 'high', 'low')
    )
def attribute_daterange_history(security, start_date, end_date, fields=('open', 'close', 'high', 'low')):
    """Return the quotes of `security` between `start_date` and `end_date`.

    Daily bars are cached in '<security>.csv'. When that file does not exist
    the bars are downloaded with pro.daily() and written there first.

    Args:
        security: Security code; also the stem of the cache file name.
        start_date: First date of the window, in the same form as the
            'trade_date' index of the cache file.
        end_date: Last date of the window, same form as `start_date`.
        fields: Columns to return.

    Returns:
        A DataFrame indexed by 'trade_date' holding the requested columns.
    """
    cache_path = security + '.csv'
    try:
        # read_csv itself raises FileNotFoundError, so no separate open() is
        # needed (the original opened the file and never closed it).
        daily_cal = pd.read_csv(cache_path, index_col='trade_date')
    except FileNotFoundError:
        pro.daily(ts_code=security).to_csv(cache_path)
        daily_cal = pd.read_csv(cache_path, index_col='trade_date')
    # The slice runs from end_date to start_date, i.e. it expects the cached
    # rows to be ordered newest first; rows are returned in that order.
    return daily_cal.loc[end_date:start_date, list(fields)]
# Order functions
# Fetch today's data
def get_today_data(security):
    """Return the open/close/high/low row of `security` for context.dt.

    Daily bars are read from '<security>.csv'; when that file does not exist
    they are downloaded with pro.daily() and written there first.

    Returns:
        The row for the current date, or None when the data has no row for
        that date (presumably a day on which the security did not trade).
    """
    today = int(context.dt.strftime('%Y%m%d'))
    cache_path = security + '.csv'
    try:
        # read_csv raises FileNotFoundError itself; the original opened the
        # file separately and never closed the handle.
        daily = pd.read_csv(cache_path, index_col='trade_date')
    except FileNotFoundError:
        pro.daily(ts_code=security).to_csv(cache_path)
        daily = pd.read_csv(cache_path, index_col='trade_date')
    # The lookup is guarded on both paths: previously a missing date right
    # after a fresh download escaped the KeyError handler.
    try:
        return daily.loc[today, ['open', 'close', 'high', 'low']]
    except KeyError:
        return None
# Number of shares traded today
def _order(today_data, security, amount):
    """Apply an order of `amount` shares of `security` at today's open price.

    A positive amount buys and a negative amount sells. The amount is
    adjusted, in this sequence, so that:
      1. the purchase does not exceed the available cash;
      2. it is a multiple of 100, unless it sells the whole position;
      3. it does not sell more shares than are held.
    context.positions and context.cash are updated in place; a position that
    reaches zero is removed.

    Args:
        today_data: Today's quote row, or None / empty when there is none.
        security: Security code.
        amount: Signed number of shares to trade.
    """
    # Check for missing data before touching the price: the original read
    # today_data['open'] first and crashed when there was no quote.
    if today_data is None or len(today_data) == 0:
        print("今日停牌")
        return
    p = today_data['open']
    held = context.positions.get(security, 0)

    if context.cash - amount * p < 0:
        amount = int(context.cash / p)
        print("现金不足,已经调整为%d" % (amount))
    if amount % 100 != 0 and amount != -held:
        # int() truncates toward zero, so sells are rounded toward zero too.
        amount = int(amount / 100) * 100
        print("不是100的倍数,已调整为%d" % amount)
    if held < -amount:
        amount = -held
        print("卖出股数超过持有数,已调整为持股数%d" % amount)

    context.positions[security] = held + amount
    context.cash -= amount * p
    if context.positions[security] == 0:
        del context.positions[security]
def order(security, amount):
    """Trade `amount` shares of `security` (negative sells) at today's open."""
    _order(get_today_data(security), security, amount)
def target_order(security, amount):
    """Trade `security` so that the number of shares held becomes `amount`.

    A negative target is reset to 0 with a printed notice.
    """
    if amount < 0:
        amount = 0
        print("数量不能为负,已调整为0")
    quote = get_today_data(security)
    currently_held = context.positions.get(security, 0)  # TODO: T+1
    _order(quote, security, amount - currently_held)
def order_value(security, value):
    """Trade roughly `value` worth of `security` at today's open price.

    The share count is `value` divided by the open price, truncated toward
    zero; a negative value therefore sells. Nothing is traded when there is
    no quote for today.
    """
    today_data = get_today_data(security)
    # get_today_data returns None when there is no row for today; dividing by
    # today_data['open'] would otherwise raise TypeError.
    if today_data is None or len(today_data) == 0:
        print("今日停牌")
        return
    amount = int(value / today_data['open'])
    _order(today_data, security, amount)
def target_order_value(security, value):
    """Trade `security` so that the position is worth about `value`.

    The position is valued at today's open price. A negative target is reset
    to 0 with a printed notice. Nothing is traded when there is no quote for
    today.
    """
    today_data = get_today_data(security)
    if value < 0:
        value = 0
        print("资产为负,已调整为0")
    # get_today_data returns None when there is no row for today.
    if today_data is None or len(today_data) == 0:
        print("今日停牌")
        return
    hold_value = context.positions.get(security, 0) * today_data['open']
    # Target minus current holding, matching target_order; the original had
    # the operands swapped and so traded in the wrong direction.
    delta_value = value - hold_value
    order_value(security, delta_value)
# Back-test
def run():
    """Run the strategy over context.date_range and plot the result.

    For every trading day the strategy callback handle_data is invoked, then
    the portfolio is valued as cash plus each position at today's open price
    (or at the last known price when there is no quote today). Finally the
    portfolio return and the benchmark return, both relative to their
    starting values, are plotted.
    """
    # float dtype keeps 'own_value' (and the derived 'ratio') numeric; an
    # untyped empty column is object dtype, which plot() does not treat as
    # numeric data.
    plt_df = pd.DataFrame(index=context.date_range, columns=['own_value'], dtype=float)
    # Last price seen per security, used to value a position on days without
    # a quote. A plain dict replaces the DataFrame that was indexed with a
    # (stock, 'price') tuple, which addresses a column rather than a cell.
    last_price = {}
    init_cash = context.cash
    initialize(context)

    for dt in context.date_range:
        # strptime avoids relying on dateutil.parser having been loaded as a
        # side effect of another import.
        context.dt = datetime.datetime.strptime(str(dt), '%Y%m%d')
        handle_data(context)

        own_value = context.cash
        for stock, shares in context.positions.items():
            today_data = get_today_data(stock)
            if today_data is None or len(today_data) == 0:
                p = last_price[stock]
            else:
                p = today_data['open']
                last_price[stock] = p
            own_value += shares * p
        plt_df.loc[dt, 'own_value'] = own_value

    plt_df['ratio'] = (plt_df['own_value'] - init_cash) / init_cash

    bm_df = attribute_daterange_history(context.benchmark, context.start_date, context.end_date)
    # Sort oldest first so that iloc[0] is the first day of the window no
    # matter in which order the history is returned.
    bm_open = bm_df['open'].sort_index()
    bm_init = bm_open.iloc[0]
    plt_df['benchmark_ratio'] = (bm_open - bm_init) / bm_init

    plt_df[['ratio', 'benchmark_ratio']].plot()
    plt.show()
def initialize(context):
    """Strategy setup: pick the traded security, the benchmark and two window lengths."""
    g.security = '000001.SZ'
    g.p1 = 5
    g.p2 = 10
    set_benchmark('000001.SZ')
def handle_data(context):
    """Daily strategy step: moving-average crossover on g.security.

    The short average is the mean of the g.p1 most recent closes and the long
    average is the mean of all g.p2 closes in the history window. When the
    short average is above the long one and nothing is held, all available
    cash is invested; when it is not above and a position is held, the
    position is sold.
    """
    hist = attribute_history(g.security, g.p2)
    # Sort oldest first so the tail really is the most recent g.p1 closes,
    # whatever order the history comes back in (the history slice runs from
    # end date to start date, so the original tail picked the oldest rows).
    closes = hist['close'].sort_index()
    ma_short = closes.iloc[-g.p1:].mean()
    ma_long = closes.mean()
    if ma_short > ma_long and g.security not in context.positions:
        order_value(g.security, context.cash)
    elif ma_short <= ma_long and g.security in context.positions:
        target_order(g.security, 0)
# Start the back-test when this script is executed.
run()