The example from the official documentation:
http://hmmlearn.readthedocs.io/en/latest/auto_examples/plot_hmm_stock_analysis.html#sphx-glr-auto-examples-plot-hmm-stock-analysis-py
However, it is out of date and the code no longer runs as published.
You need to install an additional library:
sudo pip install fix_yahoo_finance
There's also a pip package fix-yahoo-finance (sic!)
If you install this package, something like that should do the trick:
import fix_yahoo_finance as yf
quotes = yf.download("INTC", datetime.date(1995, 1, 1), datetime.date(2012, 1, 6))
quotes_matrix = quotes.reset_index().as_matrix()
However, I think we should include the data into the package, just like scikit has its toy datasets. Otherwise such issues will be reoccurring. @superbobry what do you think? From what I see you're the maintainer of hmmlearn?
My modified version:
# coding: utf-8
# In[1]:
from __future__ import print_function
import datetime
import numpy as np
from matplotlib import cm, pyplot as plt
from matplotlib.dates import YearLocator, MonthLocator
# Unused below: the quotes are downloaded with fix_yahoo_finance instead.
try:
    from matplotlib.finance import quotes_historical_yahoo_ochl
except ImportError:
    # For Matplotlib prior to 1.5.
    from matplotlib.finance import (
        quotes_historical_yahoo as quotes_historical_yahoo_ochl
    )
from hmmlearn.hmm import GaussianHMM
print(__doc__)
import fix_yahoo_finance as yf
# In[2]:
quotes = yf.download("INTC", datetime.date(1995, 1, 1), datetime.date(2012, 1, 6))
# .values is used instead of DataFrame.as_matrix(), which newer pandas versions no longer provide.
quotes_matrix = quotes.reset_index().values
# In[12]:
print(quotes)
quotes = np.array(quotes)
for q in quotes:
    open = np.array(q[0])
    close = np.array(q[3])
    volume = np.array(q[5])
    print(open, close, volume)
# Get the opening price, closing price, and trading volume.
# In[19]:
open = np.array(quotes[:,0])
close = np.array(quotes[:,3])
volume=np.array(quotes[:,5])
print(open)
# In[22]:
x=np.column_stack([close,volume])
print(x)
# In[25]:
# The code above extracts the data with NumPy; the code below fits a Gaussian HMM to it.
print("fitting to HMM and decoding ...", end="")
# Create an HMM instance and fit it to the data.
model = GaussianHMM(n_components=4, covariance_type="diag", n_iter=1000).fit(x)
# In[26]:
# Predict the most likely sequence of hidden states.
hidden_states = model.predict(x)
# In[27]:
# Print the fitted parameters, then plot the results.
print("Transition matrix")
print(model.transmat_)
# In[28]:
print("Means and vars of each hidden state")
for i in range(model.n_components):
    print("{0}th hidden state".format(i))
    print("mean = ", model.means_[i])
    print("val = ", np.diag(model.covars_[i]))
    print()
# In[30]:
fig , axs = plt.subplots(model.n_components,sharex=True,sharey=True)
colours = cm.rainbow(np.linspace(0,1,model.n_components))
# Note: plt.subplots (with the trailing s) and plt.subplot are different functions.
# In[34]:
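# plot_date expects dates on the x axis; column 0 of quotes_matrix holds the dates (the reset index).
dates = quotes_matrix[:, 0]
for i, (ax, colour) in enumerate(zip(axs, colours)):
    # Use fancy indexing to plot the data of each state.
    mask = hidden_states == i
    ax.plot_date(dates[mask], close[mask], ".-", c=colour)
    ax.set_title("{0}th hidden state".format(i))
    # Format the ticks.
    ax.xaxis.set_major_locator(YearLocator())
    ax.xaxis.set_minor_locator(MonthLocator())
    ax.grid(True)
plt.show()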
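The `model.predict(x)` call in the script above returns the most likely sequence of hidden states. As far as I know, hmmlearn decodes with the Viterbi algorithm by default. The following is a minimal NumPy sketch of that algorithm on a toy two-state problem, not hmmlearn's actual implementation. The function name `viterbi`, the toy observations, and all parameter values are made up for illustration.

```python
import numpy as np


def viterbi(log_start, log_trans, log_emit):
    """Return the most likely state path (hypothetical helper, not hmmlearn API).

    log_start: (n_states,) log initial state probabilities
    log_trans: (n_states, n_states) log transition probabilities, row = from-state
    log_emit:  (n_samples, n_states) log-likelihood of each sample under each state
    """
    n_samples, n_states = log_emit.shape
    score = np.empty((n_samples, n_states))
    backptr = np.zeros((n_samples, n_states), dtype=int)

    score[0] = log_start + log_emit[0]
    for t in range(1, n_samples):
        # cand[i, j]: best score of being in state i at t-1 and moving to state j.
        cand = score[t - 1][:, None] + log_trans
        backptr[t] = cand.argmax(axis=0)
        score[t] = cand.max(axis=0) + log_emit[t]

    # Walk the back-pointers from the best final state.
    path = np.empty(n_samples, dtype=int)
    path[-1] = score[-1].argmax()
    for t in range(n_samples - 2, -1, -1):
        path[t] = backptr[t + 1, path[t + 1]]
    return path


# Toy model (assumed values): two states with 1-D Gaussian emissions.
obs = np.array([0.1, -0.2, 9.8, 10.1, 0.3])
means = np.array([0.0, 10.0])
variances = np.array([1.0, 1.0])
start = np.array([0.5, 0.5])
trans = np.array([[0.9, 0.1],
                  [0.1, 0.9]])

# Gaussian log-density of every observation under every state, shape (5, 2).
log_emit = -0.5 * ((obs[:, None] - means) ** 2 / variances
                   + np.log(2 * np.pi * variances))

path = viterbi(np.log(start), np.log(trans), log_emit)
print(path)
```

In this toy run the observations near 0 are assigned to state 0 and those near 10 to state 1. In the script, `hidden_states` plays the same role as `path` here, with four states and two features per observation.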
Output of the modified script:
Transition matrix
[[ 9.80610945e-001 5.01334224e-003 7.89925823e-233 1.43757129e-002]
[ 3.00453847e-003 9.94724313e-001 5.06277757e-287 2.27114819e-003]
[ 1.22706578e-240 1.69469602e-106 9.97700148e-001 2.29985206e-003]
[ 1.86991637e-002 6.75484942e-029 1.17468735e-108 9.81300836e-001]]
Means and vars of each hidden state
0th hidden state
mean = [ 2.23179993e+01 6.28104280e+07]
val = [ 2.26268607e+00 4.27252350e+14]
1th hidden state
mean = [ 3.53577227e+01 5.32016661e+07]
val = [ 1.42426609e+02 4.34011309e+14]
2th hidden state
mean = [ 7.55640975e+00 8.17939427e+07]
val = [ 2.51862808e+00 2.20319259e+15]
3th hidden state
mean = [ 1.78359459e+01 7.62517437e+07]
val = [ 4.62819936e+00 1.25113880e+15]
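One way to read the transition matrix printed above: the diagonal entry a_ii is the probability of staying in state i from one observation to the next, so the expected number of consecutive observations spent in state i is 1 / (1 - a_ii), the mean of a geometric distribution. The sketch below copies the matrix from the output and computes these durations. It assumes one observation per trading day, and the variable names are my own.

```python
import numpy as np

# Transition matrix copied from the output above
# (row = current state, column = next state).
transmat = np.array([
    [9.80610945e-01, 5.01334224e-03, 7.89925823e-233, 1.43757129e-02],
    [3.00453847e-03, 9.94724313e-01, 5.06277757e-287, 2.27114819e-03],
    [1.22706578e-240, 1.69469602e-106, 9.97700148e-01, 2.29985206e-03],
    [1.86991637e-02, 6.75484942e-29, 1.17468735e-108, 9.81300836e-01],
])

# Each row is a probability distribution over the next state,
# so every row should sum to (approximately) 1.
row_sums = transmat.sum(axis=1)

# Expected number of consecutive observations spent in each state.
stay_prob = np.diag(transmat)
expected_duration = 1.0 / (1.0 - stay_prob)

for i, days in enumerate(expected_duration):
    print("state {0}: about {1:.0f} trading days per visit".format(i, days))
```

All four states are very sticky: the expected durations range from about 50 to over 400 trading days, which is why each state covers long contiguous stretches of the price history.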