之前使用 tensorflow LSTM 对第二日的收盘信息进行了预测(预测其他信息也可以,只需要对train_y进行替换)。具体如下https://blog.csdn.net/sabc123/article/details/104859617
现在对第二日的涨跌信息进行预测。原理与之前类似,只有三处大的改动:
1、在处理train_y的时候不能直接拿第二天的数据直接使用,需要与前一日的数据做差作为train_y。
2、网络模型要相应地改变,使用二分类或者多分类模型,代价函数选择相对应的函数即可。
3、对整体的数据进行简单的评估,准确率大概在60%~70%左右,后面还要优化。
文字不太多,先把部分代码放出来,整个项目的完整代码在第一篇里面可以看到,后面会增加训练信息,时间有限,慢慢更新:
1、数据处理部分
原始数据大概如下:
ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount
000002.SZ,20200312,30.01,30.63,29.6,29.95,30.38,-0.43,-1.4154,593494.28,1776382.654
000002.SZ,20200311,30.83,31.0,30.36,30.38,30.8,-0.42,-1.3636,563318.6,1728508.5869999998
000002.SZ,20200310,30.6,31.57,30.2,30.8,30.55,0.25,0.8183,765993.64,2358611.141
000002.SZ,20200309,30.6,31.13,30.01,30.55,31.13,-0.58,-1.8632,843613.44,2579914.142
000002.SZ,20200306,31.81,31.99,30.99,31.13,32.3,-1.17,-3.6223,679528.96,2124469.847
000002.SZ,20200305,31.78,32.7,31.6,32.3,32.26,0.04,0.124,924170.06,2966839.995
000002.SZ,20200304,31.28,32.45,30.74,32.26,31.1,1.16,3.7299,1129796.24,3595379.313
000002.SZ,20200303,31.51,31.63,30.43,31.1,31.13,-0.03,-0.0964,942831.95,2913732.843000001
000002.SZ,20200302,29.9,31.63,29.9,31.13,29.59,1.54,5.2045,1375747.0,4275882.3489999985
000002.SZ,20200228,29.25,30.6,29.18,29.59,29.67,-0.08,-0.2696,1101299.76,3303591.4
000002.SZ,20200227,30.23,30.26,29.38,29.67,30.11,-0.44,-1.4613,922758.82,2745257.988
000002.SZ,20200226,28.85,30.96,28.6,30.11,29.33,0.78,2.6594,1654456.71,5002513.795
000002.SZ,20200225,28.82,29.65,28.8,29.33,28.9,0.43,1.4879,1052970.0,3080948.536000001
需要去除 ts_code 列
def columnSplit(verify=False):
    """
    Split the available daily-data columns into feature columns (x) and
    label columns (y).

    :param verify: when True the label columns also include ``trade_date``
                   so the caller can check each label is paired with the
                   correct trading day
    :return: (column_x, column_y) lists of column names
    """
    # NOTE(review): `daily` is a module-level mapping defined elsewhere in the
    # project; its keys are assumed to be the CSV column names — confirm.
    all_columns = list(daily.keys())
    # Drop the stock-code column: it is an identifier, not a numeric feature.
    excluded = ['ts_code']
    column_x = [name for name in all_columns if name not in excluded]
    column_y = ["close", 'trade_date'] if verify else ["close"]
    return column_x, column_y
def dataProcess(self, df: pd.DataFrame) -> (np.array, np.array):
    """
    Turn a stock DataFrame into training features and up/down labels.

    :param df: DataFrame holding the stock's daily records
    :return: (x, y) — x drops the last row (it has no following row to
             label against); y is 1 when the next row's close is higher
             than the current row's, else 0.  With ``self.verify`` set,
             y additionally carries a "date->date :close-close" string
             column for eyeballing label alignment.
    """
    # NOTE(review): labeling assumes row i+1 is the "next" day relative to
    # row i — the raw CSV sample is newest-first, so confirm the frame is
    # re-sorted upstream before this runs.
    column_x, column_y = columnSplit(self.verify)
    features = np.array(df[column_x])
    targets = np.array(df[column_y])
    # Label each row by the sign of the close-price change to the following row.
    diffs = [targets[idx + 1][0] - targets[idx][0] for idx in range(len(targets) - 1)]
    labels = np.int32(np.array(diffs) > 0).reshape(-1, 1)
    if self.verify:
        # Attach human-readable pairing strings so dates can be spot-checked.
        notes = [
            str(targets[idx][1]) + "->" + str(targets[idx + 1][1]) + " :"
            + str(targets[idx + 1][0]) + "-" + str(targets[idx][0])
            for idx in range(len(targets) - 1)
        ]
        labels = np.concatenate((np.array(notes).reshape(-1, 1), labels), axis=1)
    # The last feature row has no next-day label, so trim it off.
    return features[:len(features) - 1], labels
2、网络模型:我使用的是多分类,loss选择了SparseCategoricalCrossentropy函数,然后进行训练
def classifyModel(shape):
    """
    Build and compile the up/down classifier.

    Architecture: four stacked LSTM layers (500, 500, 200, 50 units, all
    returning full sequences) -> Flatten -> Dense(200, relu) -> Dense(3)
    producing raw logits for SparseCategoricalCrossentropy(from_logits=True).

    :param shape: input shape excluding the batch dimension
    :return: the compiled keras Model
    """
    inputs = keras.Input(shape=shape)
    hidden = inputs
    # Stack the recurrent layers; every one keeps the time dimension so the
    # Flatten below sees the whole sequence.
    for units in (500, 500, 200, 50):
        hidden = keras.layers.LSTM(units=units, activation='tanh',
                                   return_sequences=True)(hidden)
    hidden = keras.layers.Flatten()(hidden)
    hidden = keras.layers.Dense(units=200, activation="relu")(hidden)
    logits = keras.layers.Dense(units=3)(hidden)
    model = keras.Model(inputs=inputs, outputs=logits)
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    model.summary()
    # keras.utils.plot_model(model,"picture/classify_model.png",show_shapes=True)
    return model
def trainAll(self):
    """
    Run one full training pass: consume every batch from the data
    generator, fit the model, checkpoint the weights, and record the
    stocks as trained.
    """
    for batch in self.dataPre.dataGenerator(0):
        train_x, train_y, stock_list = batch
        # Debug output: shapes of the first three feature samples plus labels.
        for sample in train_x[:3]:
            print(sample.shape)
        print(train_y.shape)
        history = self.model.fit(train_x, train_y, batch_size=self.batch_size,
                                 epochs=self.epochs, validation_split=0.05)
        # Checkpoint after every batch so progress survives interruption.
        self.model.save_weights(self.file_path)
        stock_sql.updateTrianList(stock_list)
        # Release the batch eagerly — these arrays can be large.
        del train_x, train_y, stock_list
        gc.collect()
    gc.collect()
3、评估代码:
class classifyEvaluate:
    """
    Evaluate the trained up/down classifier: for every stock, count how
    many of its predicted days match reality and plot the distribution
    of hit counts.
    """

    def __init__(self):
        self.trainMode = trainning.classifyTrain_1()

    def fitCount(self, realD: np.array, preD: np.array) -> int:
        """Return how many positions of realD and preD agree element-wise."""
        return sum(1 for i in range(len(realD)) if realD[i] == preD[i])

    def ClearRecord(self):
        """Wipe previously recorded evaluation data from the database."""
        stock_sql.ClearRecordData()

    def upDownrate(self):
        """
        Predict every stock, tally per-stock hit counts over the
        RREDICT_LEN-day window, print per-stock results, and plot a
        histogram of how many stocks achieved each hit count.
        """
        frame = stock_sql.getStockFrame()
        stocks = np.array(frame)
        # Histogram keyed by stringified hit count: 0 .. RREDICT_LEN hits.
        hit_histogram = {str(n): 0 for n in range(prepare.RREDICT_LEN + 1)}
        for stock in stocks[0:]:
            realData, predictData = self.trainMode.predict(stockInfo(symbol=stock[0]))
            # Collapse per-day class logits to the argmax class id.
            predictData = np.array(
                [np.argmax(row) for row in predictData]).reshape((-1, 1))
            if realData.shape[0] != prepare.RREDICT_LEN:
                # Skip stocks without a complete prediction window.
                continue
            hits = self.fitCount(realData, predictData)
            print("symbol:" + str(stock[0]) + " name:" + str(stock[1]) + " fit:" + str(hits))
            hit_histogram[str(hits)] += 1
        xs = [n for n in range(prepare.RREDICT_LEN + 1)]
        ys = [hit_histogram[str(n)] for n in range(prepare.RREDICT_LEN + 1)]
        plt.plot(xs, ys)
        plt.show()
        print(hit_histogram)