import talib
import matplotlib.pyplot as plt
from mplfinance.original_flavor import candlestick_ohlc
import matplotlib.ticker as ticker
import tushare as ts
# Tushare Pro client. NOTE(review): an empty token presumably falls back to a
# token saved earlier with ts.set_token() — confirm, or pass the token here.
pro = ts.pro_api('')
# Daily bars for one symbol over a date range. All dates are YYYYMMDD
# strings (start_date was previously an int, inconsistent with end_date).
data = pro.daily(
    ts_code="000001.SZ",
    trade_date="",
    start_date="20220101",
    end_date="20220601",
    offset="",
    limit="",
    fields=[
        "trade_date",
        "open",
        "high",
        "low",
        "close",
        "vol",
        "pre_close",
    ],
)
if data is None or data.empty:
    raise RuntimeError("tushare returned no daily bars for 000001.SZ")
# tushare's daily endpoint returns bars newest-first. TA-Lib indicators must
# see the series oldest-first, and the plotting code uses the row position
# (0..n-1) as the x coordinate, so sort chronologically and renumber.
# YYYYMMDD strings sort lexicographically in date order.
data = data.sort_values("trade_date").reset_index(drop=True)
# Technical indicators, all computed on the closing price and stored as new
# columns of `data` (the column order is macd*, ma7..ma30, rsi).
# MACD uses TA-Lib's default periods (fast 12, slow 26, signal 9).
data["macd"], data["macd_signal"], data["macd_hist"] = talib.MACD(data["close"])
# Moving averages (TA-Lib default matype, a simple MA). Only MA7, MA8 and
# MA25 are drawn below; MA10 and MA30 are printed but not plotted.
for period in (7, 8, 10, 25, 30):
    data[f"ma{period}"] = talib.MA(data["close"], timeperiod=period)
# RSI with TA-Lib's default 14-period window.
data["rsi"] = talib.RSI(data["close"])
print(data)
# One figure with four stacked panels that share the x axis (bar position).
# Rects are (left, bottom, width, height) in figure fractions. Every panel
# stays inside [0, 1] and leaves margins, so the y tick labels (left), the
# x tick labels (bottom) and the candle title (top) are not clipped.
fig = plt.figure(figsize=(20, 16))
ax_candle = fig.add_axes((0.05, 0.69, 0.92, 0.27))
ax_macd = fig.add_axes((0.05, 0.47, 0.92, 0.17), sharex=ax_candle)
ax_rsi = fig.add_axes((0.05, 0.25, 0.92, 0.17), sharex=ax_candle)
ax_vol = fig.add_axes((0.05, 0.04, 0.92, 0.16), sharex=ax_candle)
# Candlestick quotes as [x, open, high, low, close]. x is the row position,
# so the candles line up with the other panels, which are plotted against
# data.index. This assumes data has a 0..n-1 index — TODO confirm if
# `data` is ever re-indexed. Columns are selected by name, so the result
# does not depend on the column order tushare returns.
ohlc = [
    [row_number, openp, highp, lowp, closep]
    for row_number, (openp, highp, lowp, closep) in enumerate(
        data[["open", "high", "low", "close"]].itertuples(index=False, name=None)
    )
]
# Trade-date strings indexed by bar position; format_date() looks tick
# positions up in this array.
date_tickers = data.trade_date.values
def format_date(x, pos=None):
if x < 0 or x > len(date_tickers) - 1:
return ''
return date_tickers[int(x)]
# Price panel. The moving-average overlays are drawn first (taking the first
# three colours of the line cycle), then the candles: green up, red down.
for period in (7, 8, 25):
    ax_candle.plot(data.index, data[f"ma{period}"], label=f"MA{period}")
candlestick_ohlc(ax_candle, ohlc, colorup="g", colordown="r", width=0.8)
# Label every 6th bar with its trade date; the other panels share this axis.
ax_candle.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
ax_candle.xaxis.set_major_locator(ticker.MultipleLocator(6))
ax_candle.grid(True)
# TODO: "title" is a placeholder; replace with the symbol/period shown.
ax_candle.set_title("title", fontsize=20)
ax_candle.legend()
# MACD panel: the MACD line, the histogram as bars, then the signal line.
# The histogram is multiplied by 3 purely for visibility; the factor is an
# arbitrary display choice, not part of the indicator.
ax_macd.plot(data.index, data["macd"], label="macd")
ax_macd.bar(data.index, data["macd_hist"].mul(3), label="hist")
ax_macd.plot(data.index, data["macd_signal"], label="signal")
ax_macd.set(title='MACD')
ax_macd.legend()
# RSI panel (values in 0..100) with the conventional 70/30 overbought /
# oversold guide lines. axhline spans the whole panel width and does not
# need a dummy data series. The panel plots RSI, so it is titled RSI; it
# was previously mislabelled "KDJ", a different indicator.
ax_rsi.set_ylabel("(%)")
ax_rsi.axhline(70, color="tab:red", linestyle="--", label="overbought")
ax_rsi.axhline(30, color="tab:green", linestyle="--", label="oversold")
ax_rsi.plot(data.index, data["rsi"], label="rsi")
ax_rsi.set_ylim(0, 100)
ax_rsi.set_title('RSI')
ax_rsi.legend()
# Volume panel, with `vol` scaled down by one million.
# NOTE(review): tushare's `vol` is presumably reported in lots rather than
# shares — confirm before relying on the "(Million)" unit label.
ax_vol.set_ylabel("(Million)")
ax_vol.bar(data.index, data["vol"].div(1000000))
plt.show()