个人主页:研学社的博客
欢迎来到本博客❤️❤️a
博主优势:博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。
⛳️座右铭:行百里者,半于九十。
本文目录如下:
目录
1 概述
2 运行结果
3 参考文献
4 Python代码实现
# 归一化,便与训练 train_data_numpy = np.array(train_data) train_mean = np.mean(train_data_numpy) train_std = np.std(train_data_numpy) train_data_numpy = (train_data_numpy - train_mean) / train_std train_data_tensor = torch.Tensor(train_data_numpy) # 创建 dataloader train_set = TrainSet(train_data_tensor) train_loader = DataLoader(train_set, batch_size=10, shuffle=True)
for i in range(DAYS_BEFORE, len(all_series)): x = all_series[i - DAYS_BEFORE:i] # 将 x 填充到 (bs, ts, is) 中的 timesteps x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=2) if torch.cuda.is_available(): x = x.cuda() y = rnn(x) if i < test_start: generate_data_train.append(torch.squeeze(y.cpu()).detach().numpy() * train_std + train_mean) else: generate_data_test.append(torch.squeeze(y.cpu()).detach().numpy() * train_std + train_mean) plt.figure(figsize=(12,8)) plt.plot(df_index[DAYS_BEFORE: TRAIN_END], generate_data_train, 'b', label='generate_train', ) plt.plot(df_index[TRAIN_END:], generate_data_test, 'k', label='generate_test') plt.plot(df_index, all_series.clone().numpy()* train_std + train_mean, 'r', label='real_data') plt.legend() plt.show()
plt.figure(figsize=(10,16)) plt.subplot(2,1,1) plt.plot(df_index[100 + DAYS_BEFORE: 130 + DAYS_BEFORE], generate_data_train[100: 130], 'b', label='generate_train') plt.plot(df_index[100 + DAYS_BEFORE: 130 + DAYS_BEFORE], (all_series.clone().numpy()* train_std + train_mean)[100 + DAYS_BEFORE: 130 + DAYS_BEFORE], 'r', label='real_data') plt.legend() plt.subplot(2,1,2) plt.plot(df_index[TRAIN_END + 50: TRAIN_END + 80], generate_data_test[50:80], 'k', label='generate_test') plt.plot(df_index[TRAIN_END + 50: TRAIN_END + 80], (all_series.clone().numpy()* train_std + train_mean)[TRAIN_END + 50: TRAIN_END + 80], 'r', label='real_data') plt.legend() plt.show()
print(len(all_series_test2)) print(len(df_index)) print(len(iter_series)) plt.figure(figsize=(12,8)) plt.plot(df_index[ : len(iter_series)], iter_series, 'b', label='generate_train') plt.plot(df_index, all_series_test2.clone().numpy() * train_std + train_mean, 'r', label='real_data') plt.legend() plt.show()
部分理论来源于网络,如有侵权请联系删除。
[1]曹彦彦. LSTM模型优化及其在股指预测中的应用研究[D].东北财经大学,2022.DOI:10.27006/d.cnki.gdbcu.2022.000020.
[2]张杰. 基于LSTM的股票预测实证分析[D].山东大学,2020.DOI:10.27272/d.cnki.gshdu.2020.002958.
[3]隋金城. 基于LSTM神经网络的股票预测研究[D].青岛科技大学,2020.DOI:10.27264/d.cnki.gqdhc.2020.000423.