stock_untitl.py
from pymongo import ASCENDING
from database import DB_CONN
from datetime import datetime,timedelta
# In[2]:
def get_trading_dates(begin_date=None, end_date=None):
"""
获取指定日期范围的按照正序排列的交易日列表
如果没有指定日期范围,则获取从当期日期向前365个自然日内的所有交易日
:param begin_date: 开始日期
:param end_date: 结束日期
:return: 日期列表
"""
# 当前日期
now = datetime.now()
# 开始日期,默认今天向前的365个自然日
if begin_date is None:
# 当前日期减去365天
one_year_ago = now - timedelta(days=365)
# 转化为str类型
begin_date = one_year_ago.strftime('%Y-%m-%d')
# 结束日期默认为今天
if end_date is None:
end_date = now.strftime('%Y-%m-%d')
# 用上证综指000001作为查询条件,因为指数是不会停牌的,所以可以查询到所有的交易日
daily_cursor = DB_CONN.daily.find(
{'code': '000001', 'date': {'$gte': begin_date, '$lte': end_date}, 'index': True},
sort=[('date', ASCENDING)],
projection={'date': True, '_id': False})
# 转换为日期列表
dates = [x['date'] for x in daily_cursor]
return dates
def get_all_codes():
"""
获取所有股票代码列表
:return: 股票代码列表
"""
# 通过distinct函数拿到所有不重复的股票代码列表
return DB_CONN.basic.distinct('code')
# In[4]:
if __name__ == '__main__':
get_all_codes()
basic_crawler:
import traceback
from datetime import datetime,timedelta
import tushare as ts
from pymongo import MongoClient
from pandas.io import json
from pymongo import UpdateOne
from stock_util import get_trading_dates
DB_CONN = MongoClient('mongodb://127.0.0.1:27017')['quant_01']
# 从tushare获取股票基础数据,保存到本地的MongoDB数据库中
def crawl_basic(begin_date=None, end_date=None):
"""
抓取指定时间范围内的股票基础信息
:param begin_date: 开始日期
:param end_date: 结束日期
"""
# 如果没有指定开始日期,则默认为前一日
if begin_date is None:
begin_date = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
# 如果没有指定结束日期,则默认为前一日
if end_date is None:
end_date = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
# 获取指定日期范围的所有交易日列表
all_dates = get_trading_dates(begin_date, end_date)
# 按照每个交易日抓取
for date in all_dates:
try:
# 抓取当日的基本信息
crawl_basic_at_date(date)
except:
print('抓取股票基本信息时出错,日期:%s' % date, flush=True)
def crawl_basic_at_date(date):
"""
从Tushare抓取指定日期的股票基本信息
:param date: 日期
"""
# 从TuShare获取基本信息,index是股票代码列表
df_basics = ts.get_stock_basics(date)
# 如果当日没有基础信息,在不做操作
if df_basics is None:
return
# 初始化更新请求列表
update_requests = []
# 获取所有股票代码集合
codes = list(set(df_basics.index)) #codes = list(set(df_basics.index))[:2]
# 按照股票代码提取所有数据
for code in codes:
# 获取一只股票的数据
doc = dict(df_basics.loc[code])
try:
# API返回的数据中,上市日期是一个int类型。将上市日期,20180101转换为2018-01-01的形式
time_to_market = datetime \
.strptime(str(doc['timeToMarket']), '%Y%m%d') \
.strftime('%Y-%m-%d')
# 将总股本和流通股本转为数字
totals = float(doc['totals'])
outstanding = float(doc['outstanding'])
# 组合成基本信息文档
doc.update({
# 股票代码
'code': code,
# 日期
'date': date,
# 上市日期
'timeToMarket': time_to_market,
# 流通股本
'outstanding': outstanding,
# 总股本
'totals': totals
})
# 生成更新请求,需要按照code和date创建索引
# tushare
# numpy.int64/numpy.float64等数据类型,保存到mongodb时无法序列化。
# 解决办法:这里使用pandas.json强制转换成json字符串,然后再转换成dict。int64/float64转换成int,float
update_requests.append(
UpdateOne(
{'code': code, 'date': date},
{'$set': json.loads(json.dumps(doc))}, upsert=True))
except:
print('发生异常,股票代码:%s,日期:%s' % (code, date), flush=True)
print(doc, flush=True)
print(traceback.print_exc())
# 如果抓到了数据
if len(update_requests) > 0:
update_result = DB_CONN['basic'].bulk_write(update_requests, ordered=False)
print('抓取股票基本信息,日期:%s, 插入:%4d条,更新:%4d条' %
(date, update_result.upserted_count, update_result.modified_count), flush=True)
if __name__ == '__main__':
crawl_basic('2017-01-01', '2017-12-31')