多股回测(backtrader+quantstats+akshare)

导包

#引入技术指标数据
from __future__ import (absolute_import ,division,print_function,unicode_literals)
import datetime #用于datetime对象操作
import os.path  #用于管理路径
import sys      #用于在argv[0]中找到脚本名称
import backtrader as bt #引入backtrader框架
%matplotlib inline

策略

# Create the strategy
class TestStrategy(bt.Strategy):
    """Multi-stock moving-average crossover strategy.

    For every data feed it builds five simple moving averages of the close
    (periods ``maperiod1`` .. ``maperiod5``) and three CrossOver indicators:

    * ``D1`` = CrossOver(ma5, ma4)
    * ``A1`` = CrossOver(ma1, ma2)
    * ``C1`` = CrossOver(ma2, ma3)

    A feed without an open position is bought (``pstake`` shares) when
    ``sig1 or (sig2 and sig3)`` holds; an open position is closed when
    ``C1 < 0`` (ma2 crossing below ma3, the "death cross").
    """

    params = (
        ('maperiod1', 5),
        ('maperiod2', 13),
        ('maperiod3', 21),
        ('maperiod4', 34),
        ('maperiod5', 55),
        ('printlog', True),
        ('poneplot', False),  # overlay every feed on the first feed's chart
        ('pstake', 100000),  # shares per buy order
    )

    def log(self, txt, dt=None, doprint=False):
        """Print ``txt`` prefixed with a date when logging is enabled.

        Args:
            txt: message to print.
            dt: date to prefix; defaults to the current bar date of the
                first data feed.
            doprint: print even when the ``printlog`` param is False.
        """
        if self.params.printlog or doprint:
            dt = dt or self.datas[0].datetime.date(0)
            print('%s,%s' % (dt.isoformat(), txt))

    def __init__(self):
        """Build the per-feed moving averages and crossover signals."""
        self.inds = dict()
        periods = (
            self.params.maperiod1,
            self.params.maperiod2,
            self.params.maperiod3,
            self.params.maperiod4,
            self.params.maperiod5,
        )
        for i, d in enumerate(self.datas):
            ind = self.inds[d] = dict()
            # ma1 .. ma5, created in the same order as the periods above.
            for k, period in enumerate(periods, start=1):
                ind['ma%d' % k] = bt.indicators.SimpleMovingAverage(
                    d.close, period=period)
            ind['D1'] = bt.ind.CrossOver(ind['ma5'], ind['ma4'])
            ind['A1'] = bt.ind.CrossOver(ind['ma1'], ind['ma2'])
            ind['C1'] = bt.ind.CrossOver(ind['ma2'], ind['ma3'])
            # The first feed is the master chart; optionally plot the others on it.
            if i > 0 and self.p.poneplot:
                d.plotinfo.plotmaster = self.datas[0]

    def notify_trade(self, trade):
        """Log gross and net profit whenever a trade is closed."""
        if not trade.isclosed:
            return
        self.log('OPERATION PROFIT,GROSS %.2F,NET %.2F' %
                 (trade.pnl, trade.pnlcomm))

    def prenext(self):
        # Multi-stock backtest: run the trading logic even before every feed
        # has warmed up, so feeds that are ready can already trade.
        self.next()

    def next(self):
        """Evaluate the entry/exit signals for every feed on the current bar."""
        dt = self.datetime.date()
        for d in self.datas:
            # prenext() forwards here before all feeds are ready, and the
            # signals read the previous bar ([-1]); skip feeds with < 2 bars.
            if len(d) < 2:
                continue
            ind = self.inds[d]
            pos = self.getposition(d).size
            ma2, ma3, ma4 = ind['ma2'], ind['ma3'], ind['ma4']

            d1_up = ind['D1'][-1] > 0
            a1_up = ind['A1'][0] > 0
            ma4_rising = ma4[0] > ma4[-1]

            sig1 = d1_up and a1_up and ma2[0] > ma4[0] and ma4_rising
            # Price/volume ratios: guard the denominators, because missing
            # values are loaded as 0.0 (e.g. suspended days) and would
            # otherwise raise ZeroDivisionError.
            sig2 = ((d1_up or a1_up)
                    and ma2[0] > ma2[-1]
                    and d.open[0] > 0 and d.close[0] / d.open[0] > 1.05
                    and d.volume[-1] > 0 and d.volume[0] / d.volume[-1] > 2)
            sig3 = ((d1_up or a1_up)
                    and ma2[0] > ma3[0]
                    and ma3[0] > ma4[0]
                    and ma4_rising)
            sig4 = ind['C1'][0] < 0

            if not pos:
                # Not in the market: may buy. Parentheses make the existing
                # precedence explicit (`and` binds tighter than `or`).
                if sig1 or (sig2 and sig3):
                    self.buy(data=d, size=self.p.pstake)
                    self.log('%s,BUY CREATE, %.2f ,%s' % (dt, d.close[0], d._name))
            elif sig4:
                # In the market and ma2 crossed below ma3: exit.
                self.close(data=d)
                self.log('%s,SELL CREATE,%.2f,%s' % (dt, d.close[0], d._name))

印花税

class stampDutyCommissionScheme(bt.CommInfoBase):
    """Commission scheme that adds a stamp duty on sell orders.

    Buys pay ``commission`` times the traded value; sells pay
    ``commission + stamp_duty`` times the traded value.

    Params:
        stamp_duty: stamp-duty rate, charged on sells only.
            NOTE(review): confirm the intended rate against the current
            exchange schedule.
        percabs: True, so rates are given as plain fractions
            (0.001 == 0.1%) rather than percentages.
    """

    params = (
        ('stamp_duty', 0.005),  # stamp-duty rate (sells only)
        ('percabs', True),
    )

    def _getcommission(self, size, price, pseudoexec):
        """Return the commission for ``size`` shares traded at ``price``.

        backtrader's CommInfoBase looks up ``_getcommission``. The method was
        previously misspelled ``_gotcommission``, so it was never called and
        the stamp duty was silently never charged.

        Args:
            size: signed order size (positive = buy, negative = sell).
            price: execution price.
            pseudoexec: unused; part of backtrader's hook signature.
        """
        if size > 0:  # buy: commission only, no stamp duty
            return size * price * self.p.commission
        elif size < 0:  # sell: commission plus stamp duty
            return -size * price * (self.p.stamp_duty + self.p.commission)
        return 0

    # Backward-compatible alias for the old, misspelled name.
    _gotcommission = _getcommission

开始回测

# Backtesting engine that holds the broker, the data feeds and the strategy.
cerebro = bt.Cerebro()

# Register the strategy class (the class itself, not an instance).
cerebro.addstrategy(TestStrategy)

添加数据

# Create the price data feeds
import akshare as ak
import baostock as bs
import pandas as pd
import datetime

# Stock pool: every entry of the local data directory
from os import listdir
STOCK_DATA_DIR = 'D:/stock_data'
filename = listdir(STOCK_DATA_DIR)
stk_pools = filename

for i in stk_pools:
    datapath = os.path.join(STOCK_DATA_DIR, i)
    try:
        df = pd.read_csv(datapath)
    except (OSError, ValueError) as err:
        # Sub-directory, permission problem, empty/malformed CSV or bad
        # encoding: report and skip instead of silently swallowing it.
        print('Skipping %s: %s' % (i, err))
        continue

    # Drop stocks with too little history to warm up the 55-bar average.
    # NOTE(review): this counts rows in the whole file, not rows inside the
    # fromdate/todate window below — confirm that is intended.
    if len(df) < 55:
        continue

    try:
        data = bt.feeds.GenericCSVData(
            dataname=datapath,
            fromdate=datetime.datetime(2010, 4, 1),
            todate=datetime.datetime(2021, 7, 8),
            nullvalue=0.0,
            dtformat=('%Y-%m-%d'),
            datetime=1,
            open=2,
            high=3,
            low=4,
            close=5,
            volume=6,
            openinterest=-1,
        )
        cerebro.adddata(data, name=i)
    except Exception as err:
        # Best effort per stock: one bad feed should not stop the whole run.
        print('Skipping %s: %s' % (i, err))
        continue

设置参数

# Starting cash: 10,000 per stock, counting at most 50 entries of the pool.
# NOTE(review): this counts directory entries, not the feeds actually added
# (short or unreadable files are skipped) — confirm that is intended.
cerebro.broker.setcash(min(len(stk_pools), 50) * 10000)

# Fixed sizer of 100 shares per order.
# NOTE(review): TestStrategy passes an explicit size (pstake) to buy(), so this
# sizer does not appear to affect its orders — confirm.
cerebro.addsizer(bt.sizers.FixedSize, stake=100)

# Custom scheme with commission=0.001 (0.1%) and stamp_duty=0.005.
comminfo = stampDutyCommissionScheme(stamp_duty=0.005, commission=0.001)
cerebro.broker.addcommissioninfo(comminfo)

# Hide every feed from the plot.
for d in cerebro.datas:
    d.plotinfo.plot = False

# Report the portfolio value before the run.
starting_value = cerebro.broker.getvalue()
print('Starting Portfolio Value: %.2f' % starting_value)

回测数据分析

# Attach the PyFolio analyzer so returns can be extracted after the run.
cerebro.addanalyzer(bt.analyzers.PyFolio, _name='pyfolio')
# NOTE(review): exactbars=True trims line buffers to save memory; confirm the
# strategy's [-1] lookbacks on indicators still read valid values.
back = cerebro.run(maxcpus=12, exactbars=True, stdstats=False)


import warnings
from pathlib import Path
warnings.filterwarnings('ignore')
strat = back[0]
portfolio_stats = strat.analyzers.getbyname('pyfolio')
returns, positions, transactions, gross_lev = portfolio_stats.get_pf_items()
# Convert the returns index to UTC and drop the timezone (tz-naive index).
returns.index = returns.index.tz_convert(None)


import quantstats
REPORT_FILE = 'stats.html'
quantstats.reports.html(returns, output=REPORT_FILE, title='Stock Sentiment')


import webbrowser
# webbrowser.open() expects a URL; passing a bare relative filename is not
# portable, so open the report through an absolute file:// URI.
f = webbrowser.open(Path(REPORT_FILE).resolve().as_uri())
# Print the final result
print('Final Portfolio Value : %.2f' % cerebro.broker.getvalue())

你可能感兴趣的:(量化)