数据挖掘 沪深股市预测

导入基本模块库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.arima_model import ARMA
import warnings
from itertools import product
from datetime import datetime
warnings.filterwarnings('ignore')

加载数据

# 数据加载
df = pd.read_csv('./shanghai_1990-12-19_to_2019-2-28.csv')

将时间作为df的索引

df.Timestamp = pd.to_datetime(df.Timestamp)
df.index = df.Timestamp

效果如图

数据挖掘 沪深股市预测_第1张图片

数据探索

print(df.head())

按照月,季度,年来统计

df_month = df.resample('M').mean()
df_Q = df.resample('Q-DEC').mean()
df_year = df.resample('A-DEC').mean()

按照天,月,季度,年来显示比特币的走势

fig = plt.figure(figsize=[15, 7])
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.suptitle('上证指数', fontsize=20)
plt.subplot(221)
plt.plot(df.Price, '-', label='按天')
plt.legend()
plt.subplot(222)
plt.plot(df_month.Price, '-', label='按月')
plt.legend()
plt.subplot(223)
plt.plot(df_Q.Price, '-', label='按季度')
plt.legend()
plt.subplot(224)
plt.plot(df_year.Price, '-', label='按年')
plt.legend()
plt.show()
数据挖掘 沪深股市预测_第2张图片

ARMA模型训练

  • 设置参数范围
ps = range(0, 3)
qs = range(0, 3)
parameters = product(ps, qs)
parameters_list = list(parameters)
  • 寻找最优ARMA模型参数,即best_aic最小
results = []
best_aic = float("inf") # 正无穷
for param in parameters_list:
    try:
        model = ARMA(df_month.Price,order=(param[0], param[1])).fit()
    except ValueError:
        print('参数错误:', param)
        continue
    aic = model.aic
    if aic < best_aic:
        best_model = model
        best_aic = aic
        best_param = param
    results.append([param, model.aic])
  • 输出最优模型
result_table = pd.DataFrame(results)
result_table.columns = ['parameters', 'aic']
print('最优模型: ', best_model.summary())

指数预测

我们预测今年一年的上证指数走势,使用pd.date_range生成月字段,freq='MS’代表每月开始的日期。

df_month2 = df_month[['Price']]
date_list=pd.date_range('2019-3-31','2019-12-31', freq='M').tolist()
future = pd.DataFrame(index=date_list, columns= df_month.columns)
df_month2 = pd.concat([df_month2, future])
df_month2['forecast'] = best_model.predict(start=0, end=350)

预测结果显示

plt.figure(figsize=(20,7))
df_month2.Price.plot(label='实际指数')
df_month2.forecast.plot(color='r', ls='--', label='预测指数')
plt.legend()
plt.title('指数(月)')
plt.xlabel('时间')
plt.ylabel('指数)
plt.show()

预测结果图片显示
数据挖掘 沪深股市预测_第3张图片
预测结果数值显示
数据挖掘 沪深股市预测_第4张图片

你可能感兴趣的:(不吐槽只学习)