量化投资框架

研究了一段时间的量化投资,现在已经从入门到入土了。这里总结了一个量化投资编程的框架,分享给大家。这个框架还不够精细,也未必完全正确,希望参考它的同学自己再整理、完善一下。另外,跟着视频课程写的一个交易回测程序也一并放在下面。

具体的视频课程为:金融量化分析入门 - 网易云课堂
https://study.163.com/course/introduction/1004577035.htm?share=1&shareId=1142835134#/courseDetail?tab=1

下面是框架图:

量化投资框架_第1张图片

下面是程序:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tushare as ts
import datetime
import dateutil
import time

# Create the tushare pro API client.
# NOTE(review): the API token is hard-coded in source; it should be loaded
# from an environment variable or config file and this one rotated.
pro = ts.pro_api('b8611f97d18eca1905fe0a55f0f887938ed9b3afee177c4c703bec94')
# Download the trading calendar, cache it as trade_cal.csv, then reload it
# from that file.
# NOTE(review): presumably the CSV round trip is what makes `cal_date` parse
# as integers, which the integer comparisons elsewhere in this file rely on
# -- confirm before removing it.
trade_cal = pro.trade_cal()
trade_cal.to_csv('trade_cal.csv')
trade_cal = pd.read_csv('trade_cal.csv')


# Backtest parameters: starting cash and the date window, with dates written
# as YYYYMMDD integers (they are compared against trade_cal['cal_date']).
CASH = 100000
START_DATE = 20160206 
END_DATE = 20180504


class Context:
    """Mutable state of one backtest run.

    Attributes are filled in by the constructor and then updated by the
    order functions and the backtest loop.
    """

    def __init__(self,cash,start_date,end_date):
        """Initialise the account and the list of trading days.

        Args:
            cash: Starting cash balance.
            start_date: First calendar date of the backtest (YYYYMMDD int).
            end_date: Last calendar date of the backtest (YYYYMMDD int).
        """
        self.cash = cash
        self.start_date = start_date
        self.end_date = end_date
        self.positions = {}  # security code -> number of shares held
        self.benchmark = None  # security code used as the benchmark curve
        # Open trading days inside [start_date, end_date], taken from the
        # module-level trading calendar.
        in_window = (trade_cal['cal_date'] >= start_date) & (trade_cal['cal_date'] <= end_date)
        is_open = trade_cal['is_open'] == 1
        # Sort explicitly: the backtest loop walks this array in order and
        # must move forward in time regardless of the calendar's row order.
        self.date_range = trade_cal.loc[in_window & is_open, 'cal_date'].sort_values().values
        self.dt = None  # current simulation date, set by the backtest loop


# Single shared backtest state used by every function below.
context = Context(CASH, START_DATE, END_DATE)

def set_benchmark(security: str) -> None:
    """Record `security` as the benchmark on the shared `context`."""
    context.benchmark = security
    

# Empty class
class G:
    """Attribute bag; the strategy stores its own parameters on the instance `g`."""
    pass
g = G()


# Fetch historical quote data
def attribute_history(security,count,fields=('open','close','high','low','volume')):
    """Return the last `count` trading days of quotes before `context.dt`.

    The window ends on the calendar day before `context.dt` and starts on
    the `count`-th most recent open trading day at or before that end date.

    Args:
        security: Security code, passed through to
            `attribute_daterange_history`.
        count: Number of trading days to include; must be at least 1.
        fields: Accepted for interface compatibility. This function always
            requests open/close/high/low and does not forward `fields`.

    Returns:
        Whatever `attribute_daterange_history` returns for the window.

    Raises:
        ValueError: If `count` is smaller than 1.
        IndexError: If the calendar holds fewer than `count` open days at or
            before the end date.
    """
    if count < 1:
        # iloc[-0] would silently select the first row instead of failing.
        raise ValueError("count must be a positive integer")

    end_date = context.dt - datetime.timedelta(days=1)
    int_end_date = int(end_date.strftime("%Y%m%d"))

    eligible = (trade_cal['is_open'] == 1) & (trade_cal['cal_date'] <= int_end_date)
    # Sort ascending so that counting from the end always means "most
    # recent", independent of the calendar's row order.
    open_days = trade_cal.loc[eligible, 'cal_date'].sort_values()
    start_date = open_days.iloc[-count]
    return attribute_daterange_history(security,start_date,int_end_date,fields=('open','close','high','low'))

def attribute_daterange_history(security,start_date,end_date,fields=('open','close','high','low')):
    """Return daily quotes of `security` between two dates, newest row first.

    Quotes are read from the local cache file `<security>.csv`. If that file
    does not exist, the full daily history is downloaded through `pro.daily`,
    written to the cache file, and then read back.

    Args:
        security: Security code; also the stem of the cache file name.
        start_date: Oldest trade date to include (same type as the
            `trade_date` index, i.e. an integer as parsed from the CSV).
        end_date: Newest trade date to include.
        fields: Column names to return.

    Returns:
        A DataFrame indexed by `trade_date`, restricted to `fields`, with
        rows ordered from `end_date` down to `start_date`.

    Raises:
        KeyError: If one of `fields` is not a column of the cached data.
    """
    csv_path = security + '.csv'
    try:
        # read_csv raises FileNotFoundError itself; no separate open() is
        # needed (the previous probe handle was never closed).
        daily_cal = pd.read_csv(csv_path,index_col='trade_date')
    except FileNotFoundError:
        pro.daily(ts_code=security).to_csv(csv_path)
        daily_cal = pd.read_csv(csv_path,index_col='trade_date')

    # Label slicing from end_date to start_date needs a descending index;
    # enforce it instead of relying on the order stored in the file.
    daily_cal = daily_cal.sort_index(ascending=False)
    return daily_cal.loc[end_date:start_date,list(fields)]


# Order functions
# Fetch today's quote data
def get_today_data(security):
    """Return the open/close/high/low row of `security` for `context.dt`.

    Quotes are read from the local cache file `<security>.csv`. If that file
    does not exist, the full daily history is downloaded through `pro.daily`,
    written to the cache file, and then read back.

    Args:
        security: Security code; also the stem of the cache file name.

    Returns:
        A Series with the entries 'open', 'close', 'high' and 'low' for the
        current simulation date, or None when the data has no row for that
        date.
    """
    today = int(context.dt.strftime('%Y%m%d'))
    csv_path = security + '.csv'
    try:
        # read_csv raises FileNotFoundError itself; no separate open() is
        # needed (the previous probe handle was never closed).
        daily = pd.read_csv(csv_path,index_col='trade_date')
    except FileNotFoundError:
        pro.daily(ts_code=security).to_csv(csv_path)
        daily = pd.read_csv(csv_path,index_col='trade_date')

    # The lookup sits outside the loading try-block so that a missing date
    # yields None on both the cached and the freshly downloaded path.
    try:
        return daily.loc[today,['open','close','high','low']]
    except KeyError:
        return None

# Execute an order of a given number of shares at today's open
def _order(today_data,security,amount):
    """Apply a buy (positive) or sell (negative) order to `context`.

    The order is filled at today's 'open' price. Before filling, the amount
    is adjusted in this sequence:

    1. If cash is insufficient, it is reduced to what the cash can buy.
    2. If it is not a multiple of 100, it is truncated towards zero to a
       multiple of 100, unless it sells exactly the whole position.
    3. If it sells more than is held, it is capped at the held amount.

    Args:
        today_data: Today's quote row (must provide 'open'), or None / empty
            when there is no quote for today.
        security: Security code.
        amount: Signed number of shares to trade.

    Returns:
        None. `context.cash` and `context.positions` are updated in place;
        a position that drops to zero is removed from `context.positions`.
    """
    # Guard first: without a quote there is no price to read.
    if today_data is None or len(today_data) == 0:
        print("今日停牌")
        return

    p = today_data['open']
    held = context.positions.get(security,0)

    if context.cash - amount * p < 0:
        amount = int(context.cash / p)
        print("现金不足,已经调整为%d" % (amount))

    # Selling the entire position is allowed to be an odd lot.
    if amount % 100 != 0 and amount != -held:
        amount = int(amount/100)*100
        print("不是100的倍数,已调整为%d" % amount)

    if held < -amount:
        amount = -held
        print("卖出股数超过持有数,已调整为持股数%d" % amount)

    context.positions[security] = held + amount
    context.cash -= amount * p

    if context.positions[security] == 0:
        del context.positions[security]


def order(security,amount):
    """Trade `amount` shares of `security` using today's quote.

    Looks up today's data with `get_today_data` and hands it, together with
    the signed share count, to `_order`.
    """
    today_data = get_today_data(security)
    _order(today_data,security,amount)

def target_order(security,amount):
    """Trade so that the holding of `security` becomes `amount` shares.

    A negative target is reset to 0 (with a printed notice). The difference
    between the target and the current holding is passed to `_order` along
    with today's quote.
    """
    if amount < 0:
        print("数量不能为负,已调整为0")
        amount = 0

    quote = get_today_data(security)
    # TODO: T+1 -- shares bought today are counted as immediately sellable.
    currently_held = context.positions.get(security,0)
    _order(quote,security,amount - currently_held)

def order_value(security,value):
    """Trade roughly `value` worth of `security` at today's open price.

    The share count is `value` divided by today's 'open', truncated towards
    zero, and is then passed to `_order`.

    Args:
        security: Security code.
        value: Signed cash value to trade; positive buys, negative sells.

    Returns:
        None. Nothing is traded when there is no quote for today.
    """
    today_data = get_today_data(security)
    if today_data is None:
        # No quote for today: there is no price to convert value to shares.
        print("今日停牌")
        return

    amount = int(value / today_data['open'])
    _order(today_data,security,amount)

def target_order_value(security,value):
    """Trade so that the holding of `security` is worth about `value`.

    The current holding is valued at today's 'open'. The gap between the
    target value and that holding value is converted to a share count
    (truncated towards zero) and passed to `_order`.

    Args:
        security: Security code.
        value: Target position value; negative targets are reset to 0.

    Returns:
        None. Nothing is traded when there is no quote for today.
    """
    today_data = get_today_data(security)
    if value < 0:
        value = 0
        print("资产为负,已调整为0")

    if today_data is None:
        # No quote for today: the holding cannot be valued.
        print("今日停牌")
        return

    price = today_data['open']
    hold_value = context.positions.get(security,0) * price
    # Target minus current: positive means buy more, negative means sell.
    delta_value = value - hold_value
    # Reuse the quote already loaded instead of reading it a second time.
    _order(today_data,security,int(delta_value / price))


# Backtest
def run():
    """Run the strategy over `context.date_range` and plot the result.

    For each trading day the function sets `context.dt`, calls
    `handle_data`, and values the account as cash plus every position
    priced at that day's 'open'. When a held security has no quote for the
    day, the last price recorded for it is used instead.

    After the loop it plots two curves against the trading days: the
    account's return relative to the starting cash, and the benchmark's
    'open' price relative to its earliest value in the backtest window.
    """
    init_cash = context.cash
    initialize(context)

    last_price = {}  # security code -> most recent price used for valuation
    own_values = []

    for dt in context.date_range:
        # Dates are YYYYMMDD integers; parse with the stdlib directly.
        context.dt = datetime.datetime.strptime(str(dt), '%Y%m%d')
        handle_data(context)
        own_value = context.cash

        for stock in context.positions:
            today_data = get_today_data(stock)
            if today_data is None or len(today_data) == 0:
                # No quote today: fall back to the last recorded price.
                p = last_price[stock]
            else:
                p = today_data['open']
                last_price[stock] = p

            own_value += context.positions[stock] * p

        own_values.append(own_value)

    # Build the frame with an explicit float dtype so the columns are
    # numeric when plotted.
    plt_df = pd.DataFrame({'own_value': own_values},index=context.date_range,dtype=float)
    plt_df['ratio'] = (plt_df['own_value'] - init_cash) / init_cash

    bm_df = attribute_daterange_history(context.benchmark,context.start_date,context.end_date)
    # Sort ascending so the baseline is the earliest price in the window,
    # whatever row order the history comes back in.
    bm_open = bm_df['open'].sort_index()
    bm_init = bm_open.iloc[0]
    plt_df['benchmark_ratio'] = (bm_open - bm_init) / bm_init
    plt_df[['ratio','benchmark_ratio']].plot()
    plt.show()

    
def initialize(context):
    """Set up the strategy: benchmark, window lengths and traded security.

    Stores on `g` the values that `handle_data` reads: `p1` (length of the
    short window), `p2` (number of history days requested) and `security`
    (the code that is traded).
    """
    set_benchmark('000001.SZ')
    g.p1 = 5
    g.p2 = 10
    g.security = '000001.SZ'


def handle_data(context):
    """Daily strategy step: moving-average crossover on `g.security`.

    Compares the mean close of the most recent `g.p1` days with the mean
    close of the whole `g.p2`-day history window. When the short average is
    above the long one and nothing is held, all cash is used to buy; when it
    is at or below the long one and the security is held, the position is
    closed.
    """
    hist = attribute_history(g.security,g.p2)
    # Order oldest -> newest so that the tail really is the latest g.p1 days,
    # whatever row order the history comes back in.
    closes = hist['close'].sort_index()
    ma_short = closes.iloc[-g.p1:].mean()
    ma_long = closes.mean()

    holding = g.security in context.positions
    if ma_short > ma_long and not holding:
        order_value(g.security,context.cash)
    elif ma_short <= ma_long and holding:
        target_order(g.security,0)

# Script entry point: execute the backtest and show the plot.
run()

 

你可能感兴趣的:(nothing)