量化分析之(三)均线突破平台处理

本文筛选股价突破 60、120、250 日等均线平台的股票,均线周期可选(下文示例代码默认使用 30 日周期)。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @license : (C) Copyright 2017-2020
# @Time    : 2020/6/16 9:43
# @File    : lesson3.py
# @Software: PyCharm
# @desc    :

import sys, copy
import numpy as np
from gpx.common.gpc import *
import progressbar as pb
from datetime import datetime, timedelta

# Module logger; the logger name is this module's name prefixed with 'seq.'.
# NOTE(review): `logging` is not imported explicitly in this file; it is
# presumably re-exported by `from gpx.common.gpc import *` -- confirm.
logger = logging.getLogger('seq.' + __name__)

# List of column names; presumably the schema of a profit/fundamentals table.
# It is not referenced anywhere in the visible part of this file.
PROFIT_COLUMN = ['code', 'name', 'area', 'industry', 'roe',
                 'business_income', 'gross_profit_rate', 'profit', 'npr']

# Default configuration for gpc_line: 'period' is read by gpc_line.get_period()
# (used as the moving-average window) and 'dots' by gpc_line.get_dots().
PERIOD_DOTS = {'period': gpc_period.DAY_30, 'dots': gpc_dots.FOUR}

# Not referenced anywhere in the visible part of this file.
THRESHOLD_PERIOD = 60


def save(df, name):
    """Write ``df`` to the CSV file ``name`` as UTF-8 with a BOM.

    Returns whatever ``DataFrame.to_csv`` returns for the given target.
    """
    csv_result = df.to_csv(name, encoding='utf_8_sig')
    return csv_result


def retrieve(name):
    """Load the CSV file ``name`` into a DataFrame.

    The file is decoded as UTF-8 (BOM tolerated), the first column becomes
    the index, and values are read with ``dtype=object`` so they are not
    converted to numbers.
    """
    read_options = {
        'encoding': 'utf_8_sig',
        'index_col': 0,
        'dtype': object,
    }
    return pd.read_csv(name, **read_options)


class gpc_line:
    """Screen a stock list for closes that break above an N-day moving average.

    The moving-average window comes from ``period_dots['period']``. Matching
    rows are cached in a CSV file (see ``get_path``) so a later run can reuse
    them instead of downloading K-line data again.
    """

    def __init__(self, period_dots=PERIOD_DOTS):
        """Store the configuration dict (keys ``'period'`` and ``'dots'``)."""
        super().__init__()
        self._mean_df = pd.DataFrame()
        self._period_dots = period_dots

    def get_period(self):
        """Return the configured moving-average period."""
        return self._period_dots['period']

    def get_dots(self):
        """Return the configured ``'dots'`` value."""
        return self._period_dots['dots']

    def get_path(self, mode=gpc_handler.KEEP_UP):
        """Return the cache file name for the current period.

        ``mode`` is accepted for interface compatibility but is currently not
        used when building the name.
        """
        return f'mean_{self.get_period()}.csv'

    def filter_handler(self, in_code, handler_mode):
        """Scan every stock in ``in_code`` and keep those that broke out.

        Args:
            in_code: DataFrame with at least a ``'code'`` column (and
                ``'name'``, which ``break_out_platform`` logs). It is copied,
                not modified.
            handler_mode: forwarded to ``get_path`` to pick the output file.

        Returns:
            The DataFrame of matching rows (also saved to the cache file), or
            ``None`` when no stock matched (nothing is saved in that case).
        """
        period = self.get_period()
        all_code = copy.copy(in_code)
        ma_x = 'ma' + str(period)
        is_ma_x = 'is_' + ma_x
        all_code.loc[:, ma_x] = 0
        all_code.loc[:, is_ma_x] = False
        max_num = len(all_code)
        logger.info(f'start, number = {max_num}, period = {period}')
        with pb.ProgressBar(max_value=max_num) as bar:
            # `done` counts processed stocks. It advances for every row,
            # including those skipped for missing data or failing with an
            # exception, so the bar always reaches max_num.
            for done, idx in enumerate(all_code.index, start=1):
                try:
                    code = all_code.loc[idx, 'code']
                    # NOTE(review): `ts` is presumably tushare -- confirm.
                    k_data = ts.get_k_data(code)
                    # Only evaluate stocks with more than `period` rows.
                    if k_data is not None and len(k_data.index) > period:
                        all_code.loc[idx, [is_ma_x, ma_x]] = \
                            self.break_out_platform(code_serial=all_code.loc[idx],
                                                    data=k_data,
                                                    threshold=period)
                except Exception as e:
                    # Best effort: one failing stock must not abort the scan.
                    logger.exception(e)
                bar.update(done)
        self._mean_df = all_code[all_code[is_ma_x]]
        if self._mean_df.empty:
            return None
        logger.info(f'end, the result is {len(self._mean_df)}')
        save(self._mean_df, self.get_path(handler_mode))
        return self._mean_df

    def filter_result(self, all_code, mode=gpc_handler.KEEP_UP):
        """Return the breakout list, from the cache file when it exists.

        If the file named by ``get_path(mode)`` exists it is loaded;
        otherwise ``filter_handler`` is run on ``all_code``. A non-empty
        result is printed via ``display_code``.

        Returns:
            The DataFrame of matching stocks, or ``None`` if none were found.
        """
        period = self.get_period()
        over_ma_file = self.get_path(mode)
        logger.debug(f'mean path: {over_ma_file}')
        if os.path.exists(over_ma_file):
            code_df = retrieve(over_ma_file)
        else:
            code_df = self.filter_handler(all_code, handler_mode=mode)
        if code_df is None:
            logger.warning(f'ma, period = {period}, it do not found.')
        else:
            self.display_code(code_df)
        return code_df

    def display_code(self, loop_df):
        """Print the ``'code'`` column as one comma-terminated list."""
        print('please input stocks list: ')
        # Each code is followed by a comma, including the last one.
        _code_ = ''.join(f'{code},' for code in loop_df['code'])
        print(f'\n{_code_}\n')

    def break_out_platform(self, code_serial, data, end_date=None, threshold=60):
        """Detect a breakout above the ``threshold``-day moving average.

        A breakout day is one where ``open < MA <= close`` and
        ``gpc_calc_volume_ratio`` accepts the stock. The last such day within
        the most recent ``threshold`` rows is used. Every earlier row in that
        window must additionally satisfy
        ``-0.05 < (MA - close) / MA < 0.2``.

        Args:
            code_serial: row/mapping providing ``'code'`` and ``'name'``.
            data: K-line DataFrame with ``'date'``, ``'open'`` and ``'close'``
                columns. It is not modified.
            end_date: optional upper bound; rows after it are ignored.
            threshold: moving-average window and size of the inspected tail.

        Returns:
            ``(True, breakout_date)`` on success, otherwise ``(False, 0)``.
        """
        if len(data) < threshold:
            logger.debug("{0}:Sample less than {1} days...".format(code_serial['code'], threshold))
            return False, 0
        ma_x = 'ma' + str(threshold)
        # Work on a copy so the caller's frame does not gain the MA column.
        k_data = data.copy()
        # NOTE(review): `tl.MA` is presumably TA-Lib's moving average -- confirm.
        k_data[ma_x] = pd.Series(tl.MA(k_data['close'].values, threshold),
                                 index=k_data.index.values)
        data = k_data

        begin_date = data.iloc[0].date
        if end_date is not None:
            if end_date < begin_date:
                logger.debug("{} is not listed at {}".format(code_serial['code'], end_date))
                return False, 0
            data = data.loc[data['date'] <= end_date]

        data = data.tail(n=threshold)

        breakthrough_row = None

        for _, row in data.iterrows():
            if row['open'] < row[ma_x] <= row['close']:
                # NOTE(review): the arguments do not depend on `row`; the call
                # is left inside the loop because the helper's side effects
                # are not visible from here.
                if gpc_calc_volume_ratio(k_data, code_serial,
                                         is_amount_filter=True,
                                         threshold=5,
                                         is_logger=False):
                    breakthrough_row = row

        if breakthrough_row is None:
            return False, 0

        data_front = data.loc[(data['date'] < breakthrough_row['date'])]

        for _, row in data_front.iterrows():
            deviation = (row[ma_x] - row['close']) / row[ma_x]
            # A NaN moving average (not enough history) fails this test too.
            if not (-0.05 < deviation < 0.2):
                return False, 0

        logger.warning("{0} {1} break out date: {2}".format(code_serial['code'],
                                                            code_serial['name'],
                                                            breakthrough_row['date']))

        return True, breakthrough_row['date']


if __name__ == '__main__':
    # Script entry point: load the candidate stock list from disk and run
    # the moving-average breakout filter with the default configuration.
    stock_list = retrieve(r'stocks.csv')
    gpc_line().filter_result(stock_list)

测试结果(以 30 日均线周期为例):

[line:82  ] - mean path: mean_30.csv
[line:57  ] - start, number = 141, period = 30

300318 博晖创新 break out date: 2020-06-08

[line:75  ] - end, the result is 1

please input stocks list: 

300318,

 

你可能感兴趣的:(backtrader,tushare,量化投资)