利用Tushare将股票数据写入MySql数据库

一、使用的工具

a、SQLAlchemy。 b 、MySql。 c、python3.7

二、学习资料

SQLAlchemy

https://www.osgeo.cn/sqlalchemy/orm/tutorial.html#connecting             官网教程

https://www.cnblogs.com/Zzbj/p/10212279.htmlhttps://www.cnblogs.com/Zzbj/p/10212279.html

MySql

https://www.runoob.com/mysql/mysql-tutorial.html

三、过程简述

1、SQLAlchemy介绍

QLAlchemy是一个基于Python的ORM框架。该框架是建立在DB-API(DB-API是Python的数据库接口规范)之上,使用关系对象映射进行数据库操作。简而言之就是,将类和对象转换成SQL,然后使用数据API执行SQL并获取执行结果。

利用Tushare将股票数据写入MySql数据库_第1张图片

2、创建表

3、使用to_sql函数,写入数据

四、代码

from sqlalchemy import Column,String,Float,Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine
import time
import tushare as ts
import pandas as pd
import datetime

Base = declarative_base()
class StockBasic(Base):
    """股票列表
    is_hs	    str	N	是否沪深港通标的,N否 H沪股通 S深股通
    list_status	str	N	上市状态: L上市 D退市 P暂停上市
    exchange	str	N	交易所 SSE上交所 SZSE深交所 HKEX港交所(未上线)
    """
    __tablename__ = 'stock_basic'
    
    ts_code = Column(String(10), primary_key=True)  # TS代码
    symbol = Column(String(10))         # 股票代码
    name = Column(String(10))           # 股票名称
    area = Column(String(4))            # 所在地域
    industry = Column(String(4))        # 所属行业
    fullname = Column(String(30))       # 股票全称
    enname = Column(String(100))        # 英文全称
    market = Column(String(3))          # 市场类型 (主板/中小板/创业板)
    exchange = Column(String(4))        # 交易所代码
    curr_type = Column(String(3))       # 交易货币
    list_status = Column(String(1))     # 上市状态: L上市 D退市 P暂停上市
    list_date = Column(String(8))       # 上市日期
    delist_date = Column(String(8))     # 退市日期
    is_hs = Column(String(1))           # 是否沪深港通标的,N否 H沪股通 S深股通

class TradeCal(Base):
    '''交易日历
    exchange     str  N  交易所 SSE上交所,SZSE深交所,CFFEX 中金所,SHFE 上期所,CZCE 郑商所,DCE 大商所,INE 上能源,IB 银行间,XHKG 港交所
    start_date   str  N  开始日期
    end_date     str  N  结束日期
    is_open      str  N  是否交易 '0'休市 '1'交易
    '''    
    __tablename__ = 'trade_cal'
    
    cal_date = Column(String(8), primary_key=True)      #日历日期
    exchange = Column(String(4))                        #交易所 SSE上交所 SZSE深交所    
    is_open = Column(String(1))                         #是否交易 0休市 1交易
    
class Daily(Base):
    """日线行情
    ts_code	    str	N	股票代码(二选一)
    trade_date	str	N	交易日期(二选一)
    start_date	str	N	开始日期(YYYYMMDD)
    end_date	str	N	结束日期(YYYYMMDD)
    """
    __tablename__ = 'daily'
    
    ts_code = Column(String(10), primary_key=True)      # 股票代码
    trade_date = Column(String(8), primary_key=True)    # 交易日期
    open = Column(Float)        # 开盘价
    high = Column(Float)        # 最高价
    low = Column(Float)         # 最低价
    close = Column(Float)       # 收盘价
    pre_close = Column(Float)   # 昨收价
    change = Column(Float)      # 涨跌额
    pct_chg = Column(Float)     # 涨跌幅 (未复权,如果是复权请用 通用行情接口 )
    vol = Column(Float)         # 成交量 (手)
    amount = Column(Float)      # 成交额 (千元)

class DailyBasic(Base):
    '''每日指标
    ts_code      str  股票代码(二选一)
    trade_date   str  交易日期 (二选一)
    start_date   str  开始日期(YYYYMMDD)
    end_date     str  结束日期(YYYYMMDD)
    '''
    __tablename__ = 'daily_basic'   
    
    ts_code = Column(String(10), primary_key=True)      # 股票代码
    trade_date = Column(String(8), primary_key=True)    # 交易日期
    close = Column(Float)                               # 当日收盘价
    turnover_rate = Column(Float)                       # 换手率(%)
    turnover_rate_f = Column(Float)                     # 换手率(自由流通股)
    volume_ratio = Column(Float)                        # 量比
    pe = Column(Float)                                  # 市盈率(总市值/净利润)
    pe_ttm = Column(Float)                              # 市盈率(TTM)
    pb = Column(Float)                                  # 市净率(总市值/净资产)
    ps = Column(Float)                                  # 市销率
    ps_ttm = Column(Float)	                            # 市销率(TTM)
    dv_ratio = Column(Float)	                        # 股息率 (%)
    dv_ttm = Column(Float)		                        # 股息率(TTM)(%)
    total_share = Column(Float)	  	                    # 总股本 (万股)
    float_share = Column(Float)		                    # 流通股本 (万股)
    free_share = Column(Float)		                    # 自由流通股本 (万)
    total_mv = Column(Float)		                    # 总市值 (万元)
    circ_mv = Column(Float)	                            # 流通市值(万元)

class IndexBasic(Base):
    '''指数基本信息
    market	    str	Y	交易所或服务商
    publisher	str	N	发布商
    category	str	N	指数类别
    '''
    __tablename__ = 'index_basic'
    
    ts_code = Column(String(10),primary_key=True)              # TS代码
    name = Column(Text)                                  # 简称
    fullname = Column(Text)                              # 指数全称
    market = Column(Text)                                # 市场
    publisher = Column(Text)                             # 发布方
    index_type = Column(Text)                            # 指数风格
    category = Column(Text)                              # 指数类别
    base_date = Column(Text)                             # 基期
    base_point = Column(Float)                           # 基点
    list_date = Column(Text)                             # 发布日期
    weight_rule = Column(Text)                           # 加权方式
    desc = Column(Text)                                  # 描述
    exp_date = Column(Text)                              # 终止日期

class IndexDaily(Base):
    '''指数日线行情
    ts_code	    str	  Y	指数代码
    trade_date	str	  N	交易日期 (日期格式:YYYYMMDD,下同)
    start_date	str   N	开始日期
    end_date	None  N	结束日期
    '''
    __tablename__ = 'index_daily'
    
    ts_code = Column(String(10), primary_key=True)      # TS指数代码
    trade_date = Column(String(8), primary_key=True)    # 交易日
    close = Column(Float)	                            # 收盘点位
    open = Column(Float)	                            # 开盘点位
    high = Column(Float)	                            # 最高点位
    low = Column(Float)	                                # 最低点位
    pre_close = Column(Float)	                        # 昨日收盘点
    change = Column(Float)	                            # 涨跌点
    pct_chg = Column(Float)	                            # 涨跌幅(%)
    vol = Column(Float)	                                # 成交量(手)
    amount = Column(Float)	                            # 成交额(千元)

class IndexWeight(Base):
    '''指数成分和权重
    index_code	str	 Y	指数代码 (二选一)
    trade_date	str	 Y	交易日期 (二选一)
    start_date	str	 N	开始日期
    end_date	None N	结束日期
    '''    
    __tablename__ = 'index_weight'
     
    index_code = Column(String(10),primary_key=True)		# 指数代码
    con_code = Column(String(10),primary_key=True)		    # 成分代码
    trade_date = Column(String(8),primary_key=True)	        # 交易日期
    weight = Column(Float)		                            # 权重

class IndexDailybasic(Base):
    '''大盘指数每日指标
    trade_date	str	N	交易日期 (格式:YYYYMMDD,比如20181018,下同)
    ts_code	    str	N	TS代码
    start_date	str	N	开始日期
    end_date	str	N	结束日期   
    '''    
    __tablename__ = 'index_dailybasic'
    
    ts_code = Column(String(10), primary_key=True)		        # TS代码
    trade_date = Column(String(8), primary_key=True)	        # 交易日期
    total_mv = Column(Float)		                            # 当日总市值(元)
    float_mv = Column(Float)		                            # 当日流通市值(元)
    total_share = Column(Float)		                            # 当日总股本(股)
    float_share = Column(Float)		                            # 当日流通股本(股)
    free_share = Column(Float)		                            # 当日自由流通股本(股)
    turnover_rate = Column(Float)		                        # 换手率
    turnover_rate_f = Column(Float)		                        # 换手率(基于自由流通股本)
    pe = Column(Float)		                                    # 市盈率
    pe_ttm = Column(Float)		                                # 市盈率TTM
    pb = Column(Float)		                                    # 市净率

class IndexClassify(Base):
    '''申万行业分类
    index_code	str	N	指数代码
    level	    str	N	行业分级(L1/L2/L3)
    src	        str N	指数来源(SW申万)
    '''     
    __tablename__ = 'index_classify'
    
    index_code = Column(String(10), primary_key=True)		    # 指数代码
    industry_name = Column(String(8))		                    # 行业名称
    level = Column(String(8))		                            # 行业名称
    industry_code = Column(String(8))		                    # 行业代码
    src = Column(String(8))		                                # 行业分类(SW申万)

class IndexMember(Base):
    '''申万行业成分构成
    index_code	str	N	指数代码
    ts_code	    str	N	股票代码
    is_new	    str	N	是否最新(默认为“Y是”)
    '''
    __tablename__ = 'index_member'
    
    index_code = Column(String(10), primary_key=True)		    # 指数代码
    index_name = Column(String(8))		                        # 指数名称
    con_code = Column(String(8))		                        # 成分股票代码
    con_name = Column(String(8))		                        # 成分股票名称
    in_date = Column(String(8))		                            # 纳入日期
    out_date = Column(String(8))		                        # 剔除日期
    is_new = Column(String(8))		                            # 是否最新Y是N否
    
class OpeateData():
    '''
    数据库数据写入,修改
    '''
    def __init__(self):
        ......
        
    def get_end_dt(self):
        time_temp = datetime.datetime.now() - datetime.timedelta(days=1)
        end_dt = time_temp.strftime('%Y%m%d')  
        return end_dt
    
    def create_db(self):
        # metadata.create_all创建所有表
        Base.metadata.create_all(self.engine)
    
    def drop_db(self):
        # metadata.drop_all删除所有表
        Base.metadata.drop_all(self.engine)
        
#------------------------------------------------------- 股票列表   
    def update_stock_basic(self,engine, pro, retry_count, pause):       
        """更新 股票信息 所有数据"""
        def get_stock_basic(pro, retry_count=3, pause=2):
            """股票列表 数据"""
            frame = pd.DataFrame()
            for status in ['L', 'D', 'P']:
                for _ in range(retry_count):
                    try:
                        df = pro.stock_basic(exchange='', list_status=status,
                                             fields='ts_code,symbol,name,area,industry,fullname,enname,market, \
                                            exchange,curr_type,list_status,list_date,delist_date,is_hs')
                    except:
                        time.sleep(pause)
                    else:
                        frame=pd.concat([frame,df])
                        break
            frame.reset_index(drop=True,inplace=True)
            return frame
        
        data = get_stock_basic(pro, retry_count, pause)
        data.to_sql('stock_basic', engine, if_exists='replace', index=False)

#---------------------------------------------------------- 交易日历
    def update_trade_cal(self,engine, pro, start_dt, end_dt, retry_count, pause):
        '''交易日历更新'''
        def get_trade_cal(pro, start_dt, end_dt, retry_count=3, pause=2):
            '''获取交易日历'''
            for _ in range(retry_count):
                try:
                    df = pro.trade_cal(exchange='', start_date=start_dt, end_date=end_dt ,fields=['exchange','cal_date','is_open'])
                except:
                    time.sleep(pause)
                else:
                    break
            return df
        df = get_trade_cal(pro, start_dt, end_dt, retry_count, pause)
        df.to_sql('trade_cal', engine, if_exists='append', index=False)
        
#---------------------------------------------------------- 日线行情    
    def get_ts_code(self):
        """查询ts_code"""
        sql = 'select ts_code from stock_basic'
        return pd.read_sql(sql, self.engine)
    
    def delete_daily(self,engine, start_dt, end_dt):
        """删除 日线行情 数据"""
        conn = engine.connect()
        conn.execute('delete from daily where  trade_date between ' + start_dt + ' and ' + end_dt)

    def update_daily(self,engine, pro, start_dt, end_dt, retry_count, pause):
        codes = self.get_ts_code()
        """获取日线行情 数据"""
        def get_daily(pro, ts_code, start_dt, end_dt, retry_count=3, pause=2):
            """股票代码方式获取 日线行情 数据"""
            for _ in range(retry_count):
                try:
                    df = pro.daily(ts_code=ts_code, start_date=start_dt, end_date=end_dt,
                                   fields='ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount')
                except:
                    time.sleep(pause)
                else:
                    break
            return df

        for value in codes['ts_code']:
            df = get_daily(pro, value, start_dt, end_dt, retry_count, pause)
            df.to_sql('daily', engine, if_exists='append', index=False)
            time.sleep(0.5)    

#-------------------------------------------------------------每日指标
    def delete_daily_basic(self,engine, start_dt, end_dt):
        """删除 每日指标 数据"""
        conn = engine.connect()
        conn.execute('delete from daily_basic where  trade_date between ' + start_dt + ' and ' + end_dt)   
    
    def update_daily_basic(self,engine, pro, start_dt, end_dt, retry_count, pause):
        codes = self.get_ts_code() 
        def get_daily_basic(pro, ts_code, start_dt, end_dt, retry_count=3, pause=2):
            """获取每日指标 数据"""
            for _ in range(retry_count):
                try:
                    df =  pro.daily_basic(ts_code=ts_code, start_date=start_dt ,end_date=end_dt,
                                          fields='ts_code,trade_date,close,turnover_rate,turnover_rate_f,volume_ratio,pe,pe_ttm,pb,ps,ps_ttm,dv_ratio,dv_ttm,total_share,float_share,free_share,total_mv,circ_mv')
                except:
                    time.sleep(pause)
                else:
                    return df 
        for value in codes['ts_code']:
            df = get_daily_basic(pro, value, start_dt, end_dt, retry_count, pause)
            df.to_sql('daily_basic', engine, if_exists='append', index=False)
            time.sleep(0.5)         

#-------------------------------------------------------------指数基本信息
    def update_index_basic(self,engine, pro, retry_count, pause):  
        '''更新指数基本信息'''
        def get_index_basic(pro, retry_count=3, pause=2):
            '''获取指数基本信息'''
            index=['MSCI','CSI','SSE','SZSE','CICC','SW','OTH']
            frame = pd.DataFrame()
            for i in index:
                for _ in range(retry_count):
                    try:
                        df=pro.index_basic(market=i,fields='ts_code,name,fullname,market,publisher,index_type,category,base_date,base_point,list_date,weight_rule,desc,exp_date')
                    except:
                        time.sleep(pause)
                    else:
                        frame=pd.concat([frame,df])
                        break
            return frame
        data = get_index_basic(pro, retry_count, pause)
        data.to_sql('index_basic',engine, if_exists='replace', index=False)
        
#-------------------------------------------------------------指数日线行情
    def delete_index_daily(self,engine, start_dt, end_dt):
        """删除 每日指标 数据"""
        conn = engine.connect()
        conn.execute('delete from index_daily where  trade_date between ' + start_dt + ' and ' + end_dt)

    def update_index_daily(self,engine, pro, start_dt, end_dt, retry_count, pause):
        '''更新 每日指标数据 '''
        def get_index_daily(pro, ts_code, start_dt, end_dt, retry_count=3, pause=2):
            """获取指数日线行情 数据"""
            for _ in range(retry_count):
                try:
                    df = pro.index_daily(ts_code=ts_code, start_date=start_dt, end_date=end_dt,
                            fields=['ts_code','trade_date','close','open','high','low','pre_close','change','pct_chg','vol','amount'])
                except:
                    time.sleep(pause)
                else:
                    break
            return df
        codes = ['000001.SH','000300.SH','000905.SH','399001.SZ','399005.SZ','399006.SZ','399016.SZ','399300.SZ','000005.SH', '000006.SH','000016.SH']
        for value in codes:
            df = get_index_daily(pro, value, start_dt, end_dt, retry_count, pause)
            df.to_sql('index_daily', engine, if_exists='append', index=False)
            time.sleep(0.5)

#--------------------------------------------------------------指数成分和权重    
    def update_index_weight(self,engine ,pro ,start_dt, end_dt, retry_count=3, pause=2):
        ''' 更新 指数成分和权重 数据 '''
        def get_index_weight(pro, index_code, start_dt, end_dt, retry_count=3, pause=2):
            """获取指数成分和权重 数据"""
            for _ in range(retry_count):
                try:
                    df = pro.index_weight(index_code=index_code, start_date=start_dt, end_date=end_dt,
                            fields=['index_code', 'con_code', 'trade_date', 'weight'])
                except:
                    time.sleep(pause)
                else:
                    break
            return df
        codes = ['000001.SH','000300.SH','000905.SH','399001.SZ','399005.SZ','399006.SZ','399016.SZ','399300.SZ','000005.SH', '000006.SH','000016.SH']
        for value in codes:
            df = get_index_weight(pro, value, start_dt, end_dt, retry_count, pause)
            df.to_sql('index_weight', engine, if_exists='append', index=False)
            time.sleep(0.5) 

#--------------------------------------------------------------大盘指数每日指标    
    def update_index_dailybasic(self,engine ,pro ,start_dt, end_dt, retry_count=3, pause=2):
        ''' 更新 大盘指数每日指标 数据 '''
        def get_index_dailybasic(pro, ts_code, start_dt, end_dt, retry_count=3, pause=2):
            """获取大盘指数每日指标 数据"""
            for _ in range(retry_count):
                try:
                    df = pro.index_dailybasic(ts_code=ts_code, start_date=start_dt, end_date=end_dt,
                             fields='ts_code,trade_date,total_mv,float_mv,total_share,float_share,free_share,turnover_rate,turnover_rate_f,pe,pe_ttm,pb')
                except:
                    time.sleep(pause)
                else:
                    break
            return df
        codes = ['000001.SH','000300.SH','000905.SH','399001.SZ','399005.SZ','399006.SZ','399016.SZ','399300.SZ','000005.SH', '000006.SH','000016.SH']
        for value in codes:
            df = get_index_dailybasic(pro, value, start_dt, end_dt, retry_count, pause)
            df.to_sql('index_dailybasic', engine, if_exists='append', index=False)
            time.sleep(0.5)

#-------------------------------------------------------------申万行业分类
    def update_index_classify(self,engine, pro, retry_count, pause):       
        """更新 申万行业分类 所有数据"""
        def get_index_classify(pro, retry_count=3, pause=2):
            """申万行业分类 数据"""
            frame = pd.DataFrame()
            for status in ['L1', 'L2', 'L3']:
                for _ in range(retry_count):
                    try:
                        df = pro.index_classify(level=status, src='SW', fields='index_code,industry_name,level,industry_code,src')
                    except:
                        time.sleep(pause)
                    else:
                        frame=pd.concat([frame,df])
                        break
            frame.reset_index(drop=True,inplace=True)
            return frame
        
        data = get_index_classify(pro, retry_count, pause)
        data.to_sql('index_classify', engine, if_exists='replace', index=False)

#-------------------------------------------------------------申万行业成分构成
    def update_index_member(self,engine, pro, retry_conut, pause):
        '''更新 申万行业构成 '''
        def get_index_member(pro, retry_conut, pause):
            '''申万行业成分构成 数据'''
            sql = "select index_code from index_classify where level='L1'"
            index_code = pd.read_sql(sql, engine)
            frame = pd.DataFrame()
            for i in index_code['index_code']:
                for _ in range(retry_conut):
                    try:
                        df = pro.index_member(index_code=i,fields ='index_code,index_name,con_code,con_name,in_date,out_date,is_new')
                    except:
                        time.sleep(pause)
                    else:
                        frame=pd.concat([frame,df])
                        break
            return frame        
        data = get_index_member(pro, retry_conut, pause)
        data.to_sql('index_member', engine, if_exists='replace', index=False)
            
#-------------------------------------------------------------全部更新            
    def update_all(self,last_date):
        '''
        全部更新。记录上一次更新时间'20200101'
        
        参数
        -------------------
        last_date ; str 上一次更新的时间
        '''
        time_temp = datetime.datetime.now() - datetime.timedelta(days=1)
        end_dt = time_temp.strftime('%Y%m%d')
        
        time_temp = datetime.datetime.strptime(last_date, "%Y%m%d") + datetime.timedelta(days=1)
        start_dt = time_temp.strftime('%Y%m%d')
        if self.end_dt ==end_dt:
            self.update_stock_basic(self.engine, self.pro, 3, 2)
            self.update_trade_cal(self.engine, self.pro, start_dt , end_dt, 3, 2)
            self.update_daily(self.engine, self.pro, start_dt , end_dt, 3, 2)
            self.update_daily_basic(self.engine, self.pro, start_dt , end_dt, 3, 2)
            self.update_index_basic(self.engine, self.pro, 3, 2)
            self.update_index_daily(self.engine, self.pro, start_dt, end_dt, 3, 2)
            self.update_index_weight(self.engine, self.pro, start_dt, end_dt, 3, 2)
            self.update_index_dailybasic(self.engine, self.pro, start_dt, end_dt, 3, 2)
            self.update_index_classify(self.engine, self.pro, 3, 2)
            self.update_index_member(self.engine, self.pro, 3, 2)
            

 

你可能感兴趣的:(mysql)