# 输入特征可以根据实际情况进行选择，这里选择的输入为“收盘价”、“最高价”和“最低价”三列，用于预测未来的收盘价。
def _sliding_windows(series, seq_len, pre_len):
    """Cut *series* into overlapping (input, target) windows along axis 0.

    Window ``i`` uses ``series[i : i + seq_len]`` as the input and the
    following ``pre_len`` steps as the target.

    Returns:
        Two numpy arrays built from the lists of windows. Both have shape
        ``(0,)`` when *series* is shorter than ``seq_len + pre_len``.
    """
    window = seq_len + pre_len
    inputs, targets = [], []
    for start in range(len(series) - window + 1):
        chunk = series[start:start + window]
        inputs.append(chunk[:seq_len])
        targets.append(chunk[seq_len:window])
    return np.array(inputs), np.array(targets)


def preprocess_data(data, time_len, rate, seq_len, pre_len):
    """Split a time series chronologically and build sliding-window samples.

    The first ``int(time_len * rate)`` steps form the training part and the
    remaining steps up to ``time_len`` form the test part. Windows never
    straddle the train/test boundary.

    Args:
        data: Sequence indexable along axis 0 (time), e.g. a numpy array of
            shape (time, features) -- TODO confirm against the caller.
        time_len: Number of leading time steps of *data* to use.
        rate: Fraction of ``time_len`` assigned to the training split.
        seq_len: Number of past steps in each input window.
        pre_len: Number of future steps in each target window.

    Returns:
        ``(trainX, trainY, testX, testY)`` as numpy arrays. ``*X`` holds
        ``seq_len`` steps per sample and ``*Y`` holds ``pre_len`` steps
        (all columns of *data*, not only the predicted one).
    """
    train_size = int(time_len * rate)
    train_data = data[0:train_size]
    test_data = data[train_size:time_len]
    train_x, train_y = _sliding_windows(train_data, seq_len, pre_len)
    test_x, test_y = _sliding_windows(test_data, seq_len, pre_len)
    return train_x, train_y, test_x, test_y
def metric(pred, label):
    """Compute masked MAE, RMSE and MAPE between predictions and labels.

    Entries whose label is exactly 0 are excluded from every metric. This
    keeps MAPE finite and applies the same mask to MAE and RMSE. If every
    label is 0, all three metrics come out as 0.

    Args:
        pred: Array-like of predictions.
        label: Array-like of ground-truth values, broadcastable with *pred*.

    Returns:
        ``(mae, rmse, mape)`` as numpy scalars. MAPE is a fraction, not a
        percentage.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        mask = np.not_equal(label, 0).astype(np.float32)
        # Rescale the mask so its mean is 1. Averaging ``err * mask`` over all
        # entries then equals averaging ``err`` over the non-zero labels only.
        # With no non-zero labels this produces NaN, which nan_to_num zeroes.
        mask /= np.mean(mask)
        abs_err = np.abs(np.subtract(pred, label)).astype(np.float32)
        # Percentage error is relative to the label's magnitude. Dividing by
        # the raw label made the term negative for negative labels.
        ape = np.divide(abs_err, np.abs(label))
        mae = np.mean(np.nan_to_num(abs_err * mask))
        rmse = np.sqrt(np.mean(np.nan_to_num(np.square(abs_err) * mask)))
        mape = np.mean(np.nan_to_num(ape * mask))
    return mae, rmse, mape
class LSTM(nn.Module):
    """LSTM regressor that maps a window of features to one output vector.

    Only the LSTM output at the last time step is fed to a linear head. The
    defaults (hidden size 8, one layer, one output) reproduce the original
    fixed architecture, including its ``state_dict`` keys.
    """

    def __init__(self, feature, hidden_size=8, num_layers=1, output_size=1):
        """Build the model.

        Args:
            feature: Number of input features per time step.
            hidden_size: Width of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers.
            output_size: Number of values predicted per sample.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feature,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return predictions of shape (batch, output_size).

        *x* is (batch, seq_len, feature) because the LSTM uses
        ``batch_first=True``.
        """
        x, _ = self.lstm(x)
        # Keep only the final time step's output for the prediction head.
        return self.out(x[:, -1, :])
import torch  # needed for torch.no_grad below; harmless if already imported

# Training: predictions are de-normalised with `std`/`mean` before the loss,
# so `y` is presumably on the original price scale -- TODO confirm against the
# dataloader construction.
model.train()
for epoch in range(100):
    loss_all = 0
    n_batches = 0
    for x, y in train_dataloader:
        pre = model(x)
        loss = criterion(pre * std + mean, y)
        loss_all += loss.item()
        n_batches += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}: mean train loss {loss_all / max(n_batches, 1):.6f}")

# Inference: eval mode plus no_grad, so no autograd graph is built per batch.
model.eval()
pre_list, real_list = [], []
with torch.no_grad():
    for x, y in test_dataloader:
        pre = model(x) * std + mean
        # Flatten instead of .item() so test batches larger than 1 also work.
        # Assumes `pre` and `y` hold the same number of elements per batch.
        pre_list.extend(pre.reshape(-1).tolist())
        real_list.extend(y.reshape(-1).tolist())

mae, rmse, mape = metric(np.array(pre_list), np.array(real_list))
print(f"MAE: {mae:.4f}  RMSE: {rmse:.4f}  MAPE: {mape:.4f}")

plt.figure(figsize=(20, 8))
plt.plot(range(len(pre_list)), pre_list, color="red", label="pre")
plt.plot(range(len(real_list)), real_list, color="blue", label="real")
plt.legend()
plt.savefig("res.png")
plt.show()