本文将对backtrader的仓位管理进行介绍,具体以同时回测交易3只股票为例,查看每日仓位情况。
买入条件:5日线金叉60日线
卖出条件:5日线死叉60日线
仓位信息输出的核心代码位于策略类的next的方法中:
def next(self):
for i, d in enumerate(self.datas):
dt, dn = self.datetime.date(), d._name # 获取时间及股票代码
pos = self.getposition(d)
if not len(pos): # 不在场内,则可以买入
if self.inds[d]['cross'] > 0: # 如果金叉
self.buy(data = d, size = self.p.pstake) # 买买买
elif self.inds[d]['cross'] < 0: # 在场内,且死叉
self.close(data = d) # 卖卖卖
# 打印仓位信息
print('**************************************************************', file = self.log_file)
print(self.data.datetime.date(), file = self.log_file)
for i, d in enumerate(self.datas):
pos = self.getposition(d)
if len(pos):
print('{}, 持仓:{}, 成本价:{}, 当前价:{}, 盈亏:{:.2f}'.format(
d._name, pos.size, pos.price, pos.adjbase, pos.size * (pos.adjbase - pos.price)),
file = self.log_file)
部分输出结果为:
...
*************************************************************************************************
2019-06-24
000001, 持仓:100, 成本价:13.69, 当前价:13.69, 盈亏:0.00
*************************************************************************************************
2019-06-25
000001, 持仓:100, 成本价:13.69, 当前价:13.43, 盈亏:-26.00
*************************************************************************************************
2019-06-26
000001, 持仓:100, 成本价:13.69, 当前价:13.37, 盈亏:-32.00
*************************************************************************************************
2019-06-27
000001, 持仓:100, 成本价:13.69, 当前价:13.71, 盈亏:2.00
*************************************************************************************************
2019-06-28
000001, 持仓:100, 成本价:13.69, 当前价:13.78, 盈亏:9.00
*************************************************************************************************
2019-07-01
000001, 持仓:100, 成本价:13.69, 当前价:13.93, 盈亏:24.00
*************************************************************************************************
2019-07-02
000001, 持仓:100, 成本价:13.69, 当前价:14.18, 盈亏:49.00
*************************************************************************************************
2019-07-03
000001, 持仓:100, 成本价:13.69, 当前价:14.01, 盈亏:32.00
*************************************************************************************************
2019-07-04
000001, 持仓:100, 成本价:13.69, 当前价:13.99, 盈亏:30.00
*************************************************************************************************
2019-07-05
000001, 持仓:100, 成本价:13.69, 当前价:13.92, 盈亏:23.00
000002, 持仓:100, 成本价:29.45, 当前价:29.45, 盈亏:0.00
*************************************************************************************************
2019-07-08
000001, 持仓:100, 成本价:13.69, 当前价:13.59, 盈亏:-10.00
000002, 持仓:100, 成本价:29.45, 当前价:29.15, 盈亏:-30.00
...
上面的代码用到了多只股票同时进行策略回测,具体内容可以参见笔记(17)。
在backtrader中,使用类Position来管理仓位。Position的重要属性包括:
price:资产成本单价
size:仓位大小
adjbase:资产当前收盘价格
几点说明:
在类Strategy中,可以使用属性position或者方法getposition(self, data=None, broker=None)来访问仓位信息。
当使用getposition(self, data=None, broker=None)方法时,如果不指定参数data值,该方法会默认返回datas[0]的仓位信息,若指定参数data,则返回对应资产的仓位信息,如示例所示。
for i, d in enumerate(self.datas):
pos = self.getposition(d)
仓位信息positions实际保存在代理broker中,positions是一个Python的字典,key是一个种子数据Data Feed,value是一个类Position的对象。也就是说,positions[datas[0]]就表示第一只股票datas[0]对应的仓位信息,positions[datas[1]]就表示第二只股票datas[1]对应的仓位信息,依次类推。
类Position重写了__len__方法,可以使用len(position)来判断仓位大小是否为0。
类Position重写了__str__方法,可以通过print(position)来打印仓位的具体信息。
仓位信息示例代码:
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import datetime # 用于datetime对象操作
import os.path # 用于管理路径
import sys # 用于在argvTo[0]中找到脚本名称
import backtrader as bt # 引入backtrader框架
import pandas as pd
stk_num = 3 # 回测股票数目
# 创建策略
class SmaCross(bt.Strategy):
# 可配置策略参数
params = dict(
pfast=5, # 短期均线周期
pslow=60, # 长期均线周期
poneplot = False, # 是否打印到同一张图
pstake = 100 # 单笔交易股票数目
)
def __init__(self):
self.log_file = open('position_log.txt', 'w') # 用于输出仓位信息
self.inds = dict()
for i, d in enumerate(self.datas):
self.inds[d] = dict()
self.inds[d]['sma1'] = bt.ind.SMA(d.close, period=self.p.pfast) # 短期均线
self.inds[d]['sma2'] = bt.ind.SMA(d.close, period=self.p.pslow) # 长期均线
self.inds[d]['cross'] = bt.ind.CrossOver(self.inds[d]['sma1'], self.inds[d]['sma2'], plot = False) # 交叉信号
# 跳过第一只股票data,第一只股票data作为主图数据
if i > 0:
if self.p.poneplot:
d.plotinfo.plotmaster = self.datas[0]
def next(self):
for i, d in enumerate(self.datas):
dt, dn = self.datetime.date(), d._name # 获取时间及股票代码
pos = self.getposition(d)
if not len(pos): # 不在场内,则可以买入
if self.inds[d]['cross'] > 0: # 如果金叉
self.buy(data = d, size = self.p.pstake) # 买买买
elif self.inds[d]['cross'] < 0: # 在场内,且死叉
self.close(data = d) # 卖卖卖
# 打印仓位信息
print('*****************************************************************************', file = self.log_file)
print(self.data.datetime.date(), file = self.log_file)
for i, d in enumerate(self.datas):
pos = self.getposition(d)
if len(pos):
print('{}, 持仓:{}, 成本价:{}, 当前价:{}, 盈亏:{:.2f}'.format(
d._name, pos.size, pos.price, pos.adjbase, pos.size * (pos.adjbase - pos.price)),
file = self.log_file)
def stop(self):
self.log_file.close()
pass
cerebro = bt.Cerebro() # 创建cerebro
# 读入股票代码
stk_code_file = '../TQDat/TQDown2020v1/data/tq_wrk_code2019.csv'
stk_pools = pd.read_csv(stk_code_file, encoding = 'gbk')
if stk_num > stk_pools.shape[0]:
print('股票数目不能大于%d' % stk_pools.shape[0])
exit()
for i in range(stk_num):
stk_code = stk_pools['code'][stk_pools.index[i]]
stk_code = '%06d' % stk_code
# 读入数据
datapath = '../TQDat/day/stk/' + stk_code + '.csv'
# 创建价格数据
data = bt.feeds.GenericCSVData(
dataname = datapath,
fromdate = datetime.datetime(2019, 1, 1),
todate = datetime.datetime(2019, 12, 31),
nullvalue = 0.0,
dtformat = ('%Y-%m-%d'),
datetime = 0,
open = 1,
high = 2,
low = 3,
close = 4,
volume = 5,
openinterest = -1
)
# 在Cerebro中添加股票数据
cerebro.adddata(data, name = stk_code)
# 设置启动资金
cerebro.broker.setcash(100000.0)
# 设置佣金为零
cerebro.broker.setcommission(commission=0.00)
cerebro.addstrategy(SmaCross, poneplot = False) # 添加策略
cerebro.run() # 遍历所有数据
# 打印最后结果
print('Final Portfolio Value: %.2f' % cerebro.broker.getvalue())
cerebro.plot(style = "candlestick") # 绘图