A-Share Market Data Maintenance Program Based on tushare

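  • 1 Background
  • 2 About tushare
  • 3 Functional requirements
    • 3.1 Wrapping the tushare data interface
    • 3.2 Bulk download and update of daily bars for all A-share stocks
  • 4 Software design
  • 5 Implementation
    • 5.1 AshareDailyData.py
    • 5.2 TuShare.py
    • 5.3 utils.py
  • 6 Screenshot
  • 7 References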

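1 Background

I am new to quantitative investing and find it very interesting. I study it in my spare time and can only practise with a small amount of capital. I am currently learning vnpy-based quantitative strategies for the A-share market. The technical problems I am trying to solve are: obtaining and maintaining A-share daily bar data for free, automated order placement, a whole-market stock screener, a backtester for screening strategies, and machine-learning-based stock trend prediction.
Like-minded readers are welcome to contact me on QQ (1163962054).
tushare ID: 237684.
GitHub repository: https://github.com/PanAndy/quant_share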

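2 About tushare

tushare is a Python-based financial data interface with broad coverage: market data for stocks, funds, futures and digital currencies, plus fundamental data such as company financials and fund manager information. Most importantly, the data tushare provides is free, subject to per-account call limits. For personal development I only need A-share daily bars, so tushare is my first choice.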

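3 Functional requirements

3.1 Wrapping the tushare data interface

  1. Configure the tushare initialization parameters
  2. Following the code of the rqdata module, implement a tushare historical data interface that works with vnpy
  3. When fetching historical data, respect tushare's limit on the number of rows returned per call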
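The third requirement boils down to splitting a long date range into windows that each fit inside one call. Here is a minimal sketch of that idea, assuming at most one row per calendar day. The helper name `split_date_range` and the constant `MAX_DAYS_PER_QUERY` are made up for illustration and are not part of tushare or vnpy.

```python
from datetime import date, timedelta
from typing import Iterator, Tuple

# Assumption: a daily series has at most one row per calendar day,
# so a window of 5000 calendar days can never exceed a 5000-row cap.
MAX_DAYS_PER_QUERY = 5000


def split_date_range(start: date, end: date,
                     max_days: int = MAX_DAYS_PER_QUERY) -> Iterator[Tuple[date, date]]:
    """Yield consecutive, non-overlapping (chunk_start, chunk_end) pairs covering [start, end]."""
    chunk_start = start
    while chunk_start <= end:
        # Each chunk spans at most max_days calendar days, both ends inclusive
        chunk_end = min(end, chunk_start + timedelta(days=max_days - 1))
        yield chunk_start, chunk_end
        chunk_start = chunk_end + timedelta(days=1)


# Hypothetical usage: one (start, end) pair per request
chunks = list(split_date_range(date(1999, 11, 10), date(2020, 12, 31)))
for chunk_start, chunk_end in chunks:
    print(chunk_start.strftime("%Y%m%d"), chunk_end.strftime("%Y%m%d"))
```

Because trading days are only a subset of calendar days, a window sized in calendar days is a conservative choice: it may use a few more calls than strictly necessary, but it never truncates a response.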

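3.2 Bulk download and update of daily bars for all A-share stocks

  1. Get the codes of all A-share stocks
  2. Get all A-share trading days
  3. Download daily bars in bulk within tushare's limits and store them in a sqlite database
  4. Update the daily bars on a schedule every day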
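To make the storage and update requirements concrete, here is a rough sketch using only the standard-library `sqlite3` module. In the actual program vnpy's `database_manager` does this job. The table `daily_bar`, its columns, and both helper functions are hypothetical and are not vnpy's real schema. The point is that a primary key on (symbol, exchange, trade_date) makes repeated downloads idempotent, and the newest stored date tells the daily update where to resume.

```python
import sqlite3
from typing import Iterable, Optional, Tuple

# Hypothetical schema for illustration only; this is not vnpy's actual table layout.
SCHEMA = """
CREATE TABLE IF NOT EXISTS daily_bar (
    symbol     TEXT NOT NULL,
    exchange   TEXT NOT NULL,
    trade_date TEXT NOT NULL,
    open REAL, high REAL, low REAL, close REAL, volume REAL,
    PRIMARY KEY (symbol, exchange, trade_date)
)
"""

Row = Tuple[str, str, str, float, float, float, float, float]


def save_daily_bars(conn: sqlite3.Connection, rows: Iterable[Row]) -> None:
    """Insert rows, overwriting any existing row that has the same primary key."""
    conn.execute(SCHEMA)
    with conn:  # commits on success, rolls back on error
        conn.executemany(
            "INSERT OR REPLACE INTO daily_bar VALUES (?, ?, ?, ?, ?, ?, ?, ?)", rows)


def newest_trade_date(conn: sqlite3.Connection, symbol: str, exchange: str) -> Optional[str]:
    """Return the latest stored YYYYMMDD date for a stock, or None if nothing is stored."""
    row = conn.execute(
        "SELECT MAX(trade_date) FROM daily_bar WHERE symbol = ? AND exchange = ?",
        (symbol, exchange)).fetchone()
    return row[0]


conn = sqlite3.connect(":memory:")
save_daily_bars(conn, [
    ("600600", "SSE", "20210104", 99.0, 101.0, 98.5, 100.0, 1200.0),
    ("600600", "SSE", "20210105", 100.0, 102.0, 99.0, 101.5, 1500.0),
])
# Downloading the same day again replaces the row instead of duplicating it
save_daily_bars(conn, [("600600", "SSE", "20210105", 100.0, 102.0, 99.0, 101.5, 1500.0)])
print(newest_trade_date(conn, "600600", "SSE"))
```

Dates stored as YYYYMMDD strings sort in chronological order, which is why `MAX(trade_date)` works here without any date parsing.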

4 Software design

[Figure 1: software design diagram]

5 Implementation

5.1 AshareDailyData.py

import multiprocessing
import os
import sys
import traceback
from datetime import datetime, timedelta, time
from time import sleep

from tqdm import tqdm
from vnpy.trader.constant import Interval
from vnpy.trader.database import database_manager
from vnpy.trader.object import HistoryRequest

from utils import log

sys.path.append(os.getcwd())

from TuShare import tushare_client, to_split_ts_codes, TS_DATE_FORMATE


class AShareDailyDataManager:

    def __init__(self):
        """Bind the shared tushare client, then load the symbols and the trade calendar."""
        self.tushare_client = tushare_client
        self.symbols = None
        self.trade_cal = None
        self.init()

    def init(self):
        """Initialize the tushare client and cache its symbols and trade calendar."""
        self.tushare_client.init()
        self.symbols = self.tushare_client.symbols
        self.trade_cal = self.tushare_client.trade_cal

    def download_all(self):
        """
        Download daily bars for all A-share stocks from tushare.
        :return:
        """
        log.info("Start downloading daily bars for all A-share stocks")
        if self.symbols is not None:
            with tqdm(total=len(self.symbols)) as pbar:
                for tscode, list_date in zip(self.symbols['ts_code'], self.symbols['list_date']):
                    symbol, exchange = to_split_ts_codes(tscode)

                    pbar.set_description_str("Downloading A-share daily bars, ts_code: " + tscode)
                    start_date = datetime.strptime(list_date, TS_DATE_FORMATE)
                    req = HistoryRequest(symbol=symbol,
                                         exchange=exchange,
                                         start=start_date,
                                         end=datetime.now(),
                                         interval=Interval.DAILY)
                    bardata = self.tushare_client.query_history(req=req)

                    if bardata:
                        try:
                            database_manager.save_bar_data(bardata)
                        except Exception as ex:
                            log.error(tscode + ": failed to save data to the database")
                            log.error(ex)
                            traceback.print_exc()

                    pbar.update(1)
                    log.info(pbar.desc)

        log.info("Finished downloading daily bars for all A-share stocks")

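    def update_newest(self):
        """
        Fetch the latest bars from tushare into the local database. Assumes that all local data
        up to the newest local bar is already complete.
        :return:
        """
        log.info("Start updating the latest daily bars for all A-share stocks")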
        if self.symbols is not None:
            with tqdm(total=len(self.symbols)) as pbar:
                for tscode, list_date in zip(self.symbols['ts_code'], self.symbols['list_date']):
                    symbol, exchange = to_split_ts_codes(tscode)

                    newest_local_bar = database_manager.get_newest_bar_data(symbol=symbol,
                                                                            exchange=exchange,
                                                                            interval=Interval.DAILY)
                    if newest_local_bar is not None:
                        pbar.set_description_str("Processing ts_code: " + tscode + ", newest local bar: " +
                                                 newest_local_bar.datetime.strftime(TS_DATE_FORMATE))
                        start_date = newest_local_bar.datetime + timedelta(days=1)
                    else:
                        pbar.set_description_str("Processing ts_code: " + tscode + ", no local data")
                        start_date = datetime.strptime(list_date, TS_DATE_FORMATE)
                    req = HistoryRequest(symbol=symbol,
                                         exchange=exchange,
                                         start=start_date,
                                         end=datetime.now(),
                                         interval=Interval.DAILY)
                    bardata = self.tushare_client.query_history(req=req)
                    if bardata:
                        try:
                            database_manager.save_bar_data(bardata)
                        except Exception as ex:
                            log.error(tscode + ": failed to save data to the database")
                            log.error(ex)
                            traceback.print_exc()

                    pbar.update(1)
                    log.info(pbar.desc)

        log.info("Finished updating daily bars for all A-share stocks")

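    def check_update_all(self):
        """
        This method is very slow; calling it is not recommended.
        Use it when the local database already exists but may be missing some data.
        It checks all A-share daily bars against tushare and fills in any trading day that is missing.
        For each stock it verifies that every trading day since listing has a bar. If a day has no bar,
        it queries tushare for that day; if tushare has none either, trading was suspended that day.
        :return:
        """
        log.info("Start checking and repairing daily bars for all A-share stocks")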

        if self.symbols is not None:
            with tqdm(total=len(self.symbols)) as pbar:
                for tscode, list_date in zip(self.symbols['ts_code'], self.symbols['list_date']):
                    pbar.set_description_str("Checking A-share daily bars, ts_code: " + tscode)

                    symbol, exchange = to_split_ts_codes(tscode)

                    local_bar = database_manager.load_bar_data(symbol=symbol,
                                                               exchange=exchange,
                                                               interval=Interval.DAILY,
                                                               start=datetime.strptime(list_date, TS_DATE_FORMATE),
                                                               end=datetime.now())
                    # A set makes the membership test below O(1) per trading day
                    local_bar_dates = {bar.datetime.strftime(TS_DATE_FORMATE) for bar in local_bar}

                    index = (self.trade_cal[exchange.value][(self.trade_cal[exchange.value].cal_date == list_date)])
                    trade_cal = self.trade_cal[exchange.value].iloc[index.index[0]:]
                    for trade_date in trade_cal['cal_date']:
                        if trade_date not in local_bar_dates:
                            req = HistoryRequest(symbol=symbol,
                                                 exchange=exchange,
                                                 start=datetime.strptime(trade_date, TS_DATE_FORMATE),
                                                 end=datetime.strptime(trade_date, TS_DATE_FORMATE),
                                                 interval=Interval.DAILY)
                            bardata = self.tushare_client.query_history(req=req)
                            if bardata:
                                log.info(tscode + " missing from local database: " + trade_date)
                                try:
                                    database_manager.save_bar_data(bardata)
                                except Exception as ex:
                                    log.error(tscode + ": failed to save data to the database")
                                    log.error(ex)
                                    traceback.print_exc()
                    pbar.update(1)
                    log.info(pbar.desc)

        log.info("Finished checking and repairing daily bars for all A-share stocks")


a_share_daily_data_manager = AShareDailyDataManager()


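def auto_update(start_time: time = time(18, 0)):
    """
    Automatically update the latest daily bars into the local database after each day's close.
    """
    log.info("Starting scheduled update of daily bars for all A-share stocks")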
    run_parent(start_time=start_time)


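def run_parent(start_time: time = time(18, 0)):
    """
    Run the parent process, which launches the child download process on schedule.
    :return:
    """
    log.info("Starting parent process for the scheduled daily-bar update")

    # Update the day's bars from tushare every evening at start_time (18:00 by default)
    UPDATE_TIME = start_time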

    child_process = None

    while True:
        current_time = datetime.now().time()

        if current_time.hour == UPDATE_TIME.hour and current_time.minute == UPDATE_TIME.minute and child_process is None:
            log.info("Starting daily-bar update child process")
            child_process = multiprocessing.Process(target=run_child)
            child_process.start()
            log.info("Daily-bar update child process started")

        if (not (current_time.hour == UPDATE_TIME.hour and current_time.minute == UPDATE_TIME.minute)) \
                and child_process is not None:
            child_process.join()
            child_process = None
            log.info("Update child process finished")
            log.info("Back in the parent process of the scheduled daily-bar update")

        sleep(10)


def run_child():
    """
    Download data in the child process.
    :return:
    """
    log.info("Starting child process for the scheduled daily-bar update")

    try:
        a_share_daily_data_manager.update_newest()
    except Exception:
        log.error("Child process raised an exception")
        traceback.print_exc()


if __name__ == '__main__':
    log.info("Automatically updating daily bars for all A-share stocks")

    # a_share_daily_data_manager.download_all()
    # a_share_daily_data_manager.update_newest()
    # a_share_daily_data_manager.check_update_all()
    auto_update(start_time=time(18, 00))
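run_parent above wakes up every 10 seconds and compares the current hour and minute with the target. An alternative is to compute how long to wait until the next run and sleep once. This is only a sketch of that idea, not what the program above does, and the helper names `next_run_at` and `seconds_until_next_run` are made up for illustration.

```python
from datetime import datetime, time, timedelta


def next_run_at(now: datetime, run_at: time) -> datetime:
    """Return the next datetime strictly after `now` whose wall-clock time equals `run_at`."""
    candidate = datetime.combine(now.date(), run_at)
    if candidate <= now:
        # Today's slot has already passed (or is exactly now), so use tomorrow's
        candidate += timedelta(days=1)
    return candidate


def seconds_until_next_run(now: datetime, run_at: time) -> float:
    """Number of seconds to sleep from `now` until the next scheduled run."""
    return (next_run_at(now, run_at) - now).total_seconds()


# At 17:30, an 18:00 job is 30 minutes away
wait_seconds = seconds_until_next_run(datetime(2021, 1, 4, 17, 30), time(18, 0))
print(wait_seconds)  # 1800.0
```

A loop built on this would sleep for `seconds_until_next_run(datetime.now(), start_time)`, run the update, and repeat. That removes the dependence on the 10-second polling interval landing inside the target minute.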

5.2 TuShare.py

import requests
import tushare as ts
from tushare.pro import client
from pytz import timezone
from typing import List, Optional, Dict
import pandas as pd
from datetime import datetime, timedelta
import time
import traceback

from vnpy.trader.object import HistoryRequest, BarData
from vnpy.trader.constant import Exchange, Interval

from utils import log

CHINA_TZ = timezone("Asia/Shanghai")

tushare_token: str = ""

MAX_QUERY_SIZE: int = 5000
TS_DATE_FORMATE: str = '%Y%m%d'
MAX_QUERY_TIMES: int = 500

EXCHANGE_TS2VT: Dict[str, Exchange] = {
    'SH': Exchange.SSE,
    'SZ': Exchange.SZSE
}

EXCHANGE_VT2TS: Dict[Exchange, str] = {v: k for k, v in EXCHANGE_TS2VT.items()}


def to_ts_symbol(symbol: str, exchange: Exchange) -> str:
    """
    Convert a vnpy symbol and exchange into a tushare ts_code, e.g. 600600 on SSE becomes 600600.SH
    """
    if exchange not in EXCHANGE_VT2TS:
        raise TypeError("Only A-share stocks on the Shenzhen and Shanghai stock exchanges are supported for now!")
    return f'{symbol}.{EXCHANGE_VT2TS[exchange]}'


def to_split_ts_codes(tscode: str):
    """
    Split a tushare ts_code such as 600600.SH into a vnpy symbol and Exchange
    """
    symbol, exchange_ts = tscode.split('.')
    exchange = EXCHANGE_TS2VT[exchange_ts]
    return symbol, exchange


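class TuShareClient:
    """
    Client for querying historical data from TuShare.
    tushare daily data: updated between 15:00 and 16:00 on each trading day; the daily interface
    returns unadjusted prices and provides no data while a stock is suspended.
    tushare call limits: with basic points, at most 500 calls per minute and 5000 rows per call.
    """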

    def __init__(self):
        """"""

        self.pro: client.DataApi = None

        self.inited: bool = False

        # All stock codes
        self.symbols: pd.DataFrame = None

        # Trade calendar, keyed by exchange
        self.trade_cal: Dict[str, pd.DataFrame] = None

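    def init(self, token: str = "") -> bool:
        """Set the token, create the pro API client, and cache the symbols and the trade calendar."""
        if self.inited:
            return True

        # Use the token passed in by the caller, falling back to the module-level default
        ts.set_token(token if token else tushare_token)

        try:
            self.pro = ts.pro_api()
            self.stock_list()
            self.trade_day_list()
        except Exception as ex:
            # An except clause may only list exception classes, so the message is logged instead
            log.error("Failed to connect to tushare")
            log.error(ex)
            return False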

        self.inited = True
        return True

    def query_history(self, req: HistoryRequest) -> Optional[List[BarData]]:
        """
        Query historical data from tushare.
        :param req: the history request
        :return: Optional[List[BarData]]
        """
        if self.symbols is None:
            return None

        symbol = req.symbol
        exchange = req.exchange
        interval = req.interval
        start = req.start.strftime(TS_DATE_FORMATE)
        end = req.end.strftime(TS_DATE_FORMATE)

        if interval is not Interval.DAILY:
            return None
        if exchange not in [Exchange.SSE, Exchange.SZSE]:
            return None

        tscode = to_ts_symbol(symbol, exchange)

        # Query in a loop: with a limit of 5000 rows per call, one request may not cover the whole range
        cnt = 0
        df: pd.DataFrame = None
        while datetime.strptime(start, TS_DATE_FORMATE) <= datetime.strptime(end, TS_DATE_FORMATE):
            # Cap each query window at MAX_QUERY_SIZE calendar days, which keeps it within 5000 daily rows
            start_date = datetime.strptime(start, TS_DATE_FORMATE)
            simulate_end_date = min(datetime.strptime(end, TS_DATE_FORMATE),
                                    start_date + timedelta(days=MAX_QUERY_SIZE))
            simulate_end = simulate_end_date.strftime(TS_DATE_FORMATE)

            # Calls must be spaced at least 60/500 = 0.12 s apart to stay under 500 calls per minute
            # begin_time = time.time()
            tushare_df = None
            while True:
                try:
                    tushare_df = self.pro.query('daily', ts_code=tscode, start_date=start, end_date=simulate_end)
                except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                    log.error(e)
                    # traceback.print_exc()
                    # Example on Windows: ('Connection aborted.', ConnectionResetError(10054, 'An existing connection was forcibly closed by the remote host.', None, 10054, None))
                    if '10054' in str(e):
                        sleep_time = 60.0
                        log.info("Requests too frequent, sleep: " + str(sleep_time) + "s")
                        time.sleep(sleep_time)
                        log.info("Resending request: " + tscode)
                        continue  # retry the request
                    else:
                        raise  # re-raise any other error with its original type and traceback
                break
            if tushare_df is not None:
                if df is None:
                    df = tushare_df
                else:
                    df = pd.concat([df, tushare_df], ignore_index=True)
            # end_time = time.time()
            # delta = round(end_time - begin_time, 3)
            # if delta < 60 / MAX_QUERY_TIMES:
            sleep_time = 0.5
            log.info("sleep:" + str(sleep_time) + "s")
            time.sleep(sleep_time)

            cnt += 1
            start = (simulate_end_date + timedelta(days=1)).strftime(TS_DATE_FORMATE)

        data: List[BarData] = []

        if df is not None:
            for ix, row in df.iterrows():
                date = datetime.strptime(row.trade_date, '%Y%m%d')
                date = CHINA_TZ.localize(date)

                # Log the first missing field, if any, before it is filled with 0
                for field in ('open', 'high', 'low', 'close', 'vol'):
                    if pd.isnull(row[field]):
                        log.info(tscode + ' ' + row['trade_date'] + ' ' + field + " is None")
                        break

                # fillna returns a new Series, so the result has to be reassigned
                row = row.fillna(0)
                bar = BarData(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval,
                    datetime=date,
                    open_price=row['open'],
                    high_price=row['high'],
                    low_price=row['low'],
                    close_price=row['close'],
                    # In the tushare daily interface 'vol' is the traded volume; 'amount' is the turnover
                    volume=row['vol'],
                    gateway_name='tushare'
                )

                data.append(bar)
        return data

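    def stock_list(self):
        """
        Call the tushare stock_basic interface
        to get all stock codes on the Shanghai and Shenzhen stock exchanges,
        with basic information such as code, name, listing date and delisting date.
        :return:
        """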
        if self.symbols is None:
            symbols_sse = self.pro.query('stock_basic', exchange=Exchange.SSE.value, fields='ts_code,symbol,name,'
                                                                                            'fullname,enname,market,'
                                                                                            'list_status,list_date,'
                                                                                            'delist_date,is_hs')
            symbols_szse = self.pro.query('stock_basic', exchange=Exchange.SZSE.value, fields='ts_code,symbol,name,'
                                                                                              'fullname,enname,market,'
                                                                                              'list_status,list_date,'
                                                                                              'delist_date,is_hs')
            self.symbols = pd.concat([symbols_sse, symbols_szse], axis=0, ignore_index=True)

    def trade_day_list(self):
        """
        Query the trade calendar.
        :return:
        """
        if self.trade_cal is None:
            self.trade_cal = dict()
            self.trade_cal[Exchange.SSE.value] = self.pro.query('trade_cal', exchange=Exchange.SSE.value, is_open='1')
            self.trade_cal[Exchange.SZSE.value] = self.pro.query('trade_cal', exchange=Exchange.SZSE.value, is_open='1')


tushare_client = TuShareClient()

if __name__ == "__main__":
    print("Testing the TuShare data interface")
    # tushare_client = TuShareClient()
    tushare_client.init()
    # print(tushare_client.symbols)
    # print(tushare_client.trade_cal)

    req = HistoryRequest(symbol='600600', exchange=Exchange.SSE,
                         start=datetime(year=1999, month=11, day=10), end=datetime.now(), interval=Interval.DAILY)

    ts_data = tushare_client.query_history(req)
    print(len(ts_data))
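query_history above sleeps a fixed 0.5 s after every call, while the commented-out timing code hints at a tighter bound of 60/500 = 0.12 s between calls. Here is a minimal sketch of that tighter pacing, assuming the limit really is 500 calls per minute as stated in the class docstring. The class `MinIntervalLimiter` is hypothetical and not part of tushare. Its clock and sleep functions are injectable so the logic can be checked without waiting.

```python
import time
from typing import Callable, Optional


class MinIntervalLimiter:
    """Block so that consecutive calls to wait() are at least min_interval seconds apart."""

    def __init__(self, max_calls_per_minute: int,
                 clock: Callable[[], float] = time.monotonic,
                 sleep: Callable[[float], None] = time.sleep):
        self.min_interval = 60.0 / max_calls_per_minute
        self._clock = clock
        self._sleep = sleep
        self._last_call: Optional[float] = None

    def wait(self) -> float:
        """Sleep only as long as needed and return the number of seconds slept."""
        now = self._clock()
        delay = 0.0
        if self._last_call is not None:
            # Time still owed since the previous call, never negative
            delay = max(0.0, self.min_interval - (now - self._last_call))
        if delay > 0:
            self._sleep(delay)
        self._last_call = now + delay
        return delay


# Hypothetical usage: call wait() right before each request
limiter = MinIntervalLimiter(max_calls_per_minute=500)
for _ in range(3):
    limiter.wait()
    # a real request such as the pro.query('daily', ...) call above would go here
```

Compared with a fixed sleep, this only pauses for the time not already spent inside the request itself, so slow responses do not add extra waiting on top.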

5.3 utils.py

import logging


class logger:
    def __init__(self, path, clevel=logging.INFO, Flevel=logging.INFO):
        self.logger = logging.getLogger(path)
        self.logger.setLevel(logging.DEBUG)
        fmt = logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s', '%Y-%m-%d %H:%M:%S')
        # Console handler
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel(clevel)
        # File handler
        fh = logging.FileHandler(path, encoding='utf-8')
        fh.setFormatter(fmt)
        fh.setLevel(Flevel)
        self.logger.addHandler(sh)
        self.logger.addHandler(fh)

    def debug(self, message):
        self.logger.debug(message)

    def info(self, message):
        self.logger.info(message)

    def war(self, message):
        self.logger.warning(message)  # Logger.warn is a deprecated alias of warning

    def error(self, message):
        self.logger.error(message)

    def cri(self, message):
        self.logger.critical(message)


log = logger("log.txt")

6 Screenshot

[Figure 2: screenshot of the program running]

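7 References

  1. Bulk download and update of whole-market futures data
  2. Replacing the paid RQData with free data from the TianQin SDK
  3. vnpy without rqdata: trying tushare
  4. tushare documentation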
