Keras LSTM Time Series Prediction
Data records in international-airline-passengers.csv:
time,passengers
"1949-01",112
"1949-02",118
"1949-03",132
"1949-04",129
"1949-05",121
"1949-06",135
"1949-07",148
"1949-08",148
"1949-09",136
"1949-10",119
"1949-11",104
"1949-12",118
"1950-01",115
"1950-02",126
"1950-03",141
"1950-04",135
"1950-05",125
"1950-06",149
"1950-07",170
"1950-08",170
"1950-09",158
"1950-10",133
"1950-11",114
"1950-12",140
"1951-01",145
"1951-02",150
"1951-03",178
"1951-04",163
"1951-05",172
"1951-06",178
"1951-07",199
"1951-08",199
"1951-09",184
"1951-10",162
"1951-11",146
"1951-12",166
"1952-01",171
"1952-02",180
"1952-03",193
"1952-04",181
"1952-05",183
"1952-06",218
"1952-07",230
"1952-08",242
"1952-09",209
"1952-10",191
"1952-11",172
"1952-12",194
"1953-01",196
"1953-02",196
"1953-03",236
"1953-04",235
"1953-05",229
"1953-06",243
"1953-07",264
"1953-08",272
"1953-09",237
"1953-10",211
"1953-11",180
"1953-12",201
"1954-01",204
"1954-02",188
"1954-03",235
"1954-04",227
"1954-05",234
"1954-06",264
"1954-07",302
"1954-08",293
"1954-09",259
"1954-10",229
"1954-11",203
"1954-12",229
"1955-01",242
"1955-02",233
"1955-03",267
"1955-04",269
"1955-05",270
"1955-06",315
"1955-07",364
"1955-08",347
"1955-09",312
"1955-10",274
"1955-11",237
"1955-12",278
"1956-01",284
"1956-02",277
"1956-03",317
"1956-04",313
"1956-05",318
"1956-06",374
"1956-07",413
"1956-08",405
"1956-09",355
"1956-10",306
"1956-11",271
"1956-12",306
"1957-01",315
"1957-02",301
"1957-03",356
"1957-04",348
"1957-05",355
"1957-06",422
"1957-07",465
"1957-08",467
"1957-09",404
"1957-10",347
"1957-11",305
"1957-12",336
"1958-01",340
"1958-02",318
"1958-03",362
"1958-04",348
"1958-05",363
"1958-06",435
"1958-07",491
"1958-08",505
"1958-09",404
"1958-10",359
"1958-11",310
"1958-12",337
"1959-01",360
"1959-02",342
"1959-03",406
"1959-04",396
"1959-05",420
"1959-06",472
"1959-07",548
"1959-08",559
"1959-09",463
"1959-10",407
"1959-11",362
"1959-12",405
"1960-01",417
"1960-02",391
"1960-03",419
"1960-04",461
"1960-05",472
"1960-06",535
"1960-07",622
"1960-08",606
"1960-09",508
"1960-10",461
"1960-11",390
"1960-12",432
The Keras LSTM time-series script, lstm_airline_predict.py:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense, Activation

def load_data(file_name, sequence_length=10, split=0.8):
    df = pd.read_csv(file_name, sep=',', usecols=[1])
    data_all = np.array(df).astype(float)
    # Scale the whole series into [0, 1]; keep the scaler to invert predictions later
    scaler = MinMaxScaler()
    data_all = scaler.fit_transform(data_all)
    # Build sliding-window samples: sequence_length inputs plus one target each
    data = []
    for i in range(len(data_all) - sequence_length - 1):
        data.append(data_all[i: i + sequence_length + 1])
    reshaped_data = np.array(data).astype('float64')
    np.random.shuffle(reshaped_data)
    # The last column of each window is the target y; the preceding columns are x
    # (both x and y come from the already-scaled series)
    x = reshaped_data[:, :-1]
    y = reshaped_data[:, -1]
    split_boundary = int(reshaped_data.shape[0] * split)
    train_x = x[: split_boundary]
    test_x = x[split_boundary:]
    train_y = y[: split_boundary]
    test_y = y[split_boundary:]
    return train_x, train_y, test_x, test_y, scaler

def build_model():
    # train_x has shape (n_samples, time_steps, input_dim); here input_dim is 1
    model = Sequential()
    model.add(LSTM(50, input_shape=(None, 1), return_sequences=True))
    model.add(LSTM(100, return_sequences=False))
    model.add(Dense(1))
    model.add(Activation('linear'))
    model.compile(loss='mse', optimizer='rmsprop')
    return model

def train_model(train_x, train_y, test_x, test_y):
    model = build_model()
    predict = None
    try:
        model.fit(train_x, train_y, batch_size=512, epochs=30, validation_split=0.1)
        predict = model.predict(test_x)
        predict = np.reshape(predict, (predict.size, ))
    except KeyboardInterrupt:
        pass
    print(predict)
    print(test_y)
    try:
        fig = plt.figure(1)
        plt.plot(predict, 'r:')
        plt.plot(test_y, 'g-')
        plt.legend(['predict', 'true'])
    except Exception as e:
        print(e)
    return predict, test_y

if __name__ == '__main__':
    train_x, train_y, test_x, test_y, scaler = load_data('international-airline-passengers.csv')
    # LSTM layers expect 3-D input: (samples, time_steps, features)
    train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
    test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))
    predict_y, test_y = train_model(train_x, train_y, test_x, test_y)
    # Map the scaled predictions back to actual passenger counts
    predict_y = scaler.inverse_transform([[i] for i in predict_y])
    test_y = scaler.inverse_transform(test_y)
    fig2 = plt.figure(2)
    plt.plot(predict_y, 'g:')
    plt.plot(test_y, 'r-')
    plt.show()
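The sliding-window construction in load_data() is easiest to see on a toy series: each sample packs sequence_length consecutive values as input, with the immediately following value as the target. A minimal sketch with a stand-in series of ten points:

```python
import numpy as np

# Toy illustration of the windowing in load_data():
# each sample = sequence_length inputs + 1 target.
series = np.arange(10, dtype=float)   # stand-in for the scaled passenger counts
sequence_length = 3
samples = np.array([series[i: i + sequence_length + 1]
                    for i in range(len(series) - sequence_length - 1)])
x = samples[:, :-1]   # inputs: sequence_length consecutive values
y = samples[:, -1]    # target: the value right after each window
print(x.shape, y.shape)   # (6, 3) (6,)
```

With sequence_length=10 on the 144-month dataset, the same loop yields 133 samples, of which 80% (106) go to training before the 0.1 validation split inside fit() leaves 95 for the actual updates, matching the "95/95" in the log below.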
Run output:
Epoch 1/30
95/95 [==============================] - 5s 53ms/step - loss: 0.1793 - val_loss: 0.1028
Epoch 2/30
95/95 [==============================] - 0s 412us/step - loss: 0.1015 - val_loss: 0.0528
Epoch 3/30
95/95 [==============================] - 0s 353us/step - loss: 0.0532 - val_loss: 0.0183
Epoch 4/30
95/95 [==============================] - 0s 359us/step - loss: 0.0204 - val_loss: 0.0113
Epoch 5/30
95/95 [==============================] - 0s 448us/step - loss: 0.0145 - val_loss: 0.0119
Epoch 6/30
95/95 [==============================] - 0s 507us/step - loss: 0.0140 - val_loss: 0.0114
Epoch 7/30
95/95 [==============================] - 0s 439us/step - loss: 0.0135 - val_loss: 0.0120
Epoch 8/30
95/95 [==============================] - 0s 373us/step - loss: 0.0132 - val_loss: 0.0118
Epoch 9/30
95/95 [==============================] - 0s 454us/step - loss: 0.0129 - val_loss: 0.0127
Epoch 10/30
95/95 [==============================] - 0s 413us/step - loss: 0.0129 - val_loss: 0.0127
Epoch 11/30
95/95 [==============================] - 0s 418us/step - loss: 0.0129 - val_loss: 0.0147
Epoch 12/30
95/95 [==============================] - 0s 369us/step - loss: 0.0139 - val_loss: 0.0145
Epoch 13/30
95/95 [==============================] - 0s 485us/step - loss: 0.0141 - val_loss: 0.0182
Epoch 14/30
95/95 [==============================] - 0s 459us/step - loss: 0.0166 - val_loss: 0.0146
Epoch 15/30
95/95 [==============================] - 0s 549us/step - loss: 0.0138 - val_loss: 0.0168
Epoch 16/30
95/95 [==============================] - 0s 423us/step - loss: 0.0149 - val_loss: 0.0141
Epoch 17/30
95/95 [==============================] - 0s 401us/step - loss: 0.0129 - val_loss: 0.0155
Epoch 18/30
95/95 [==============================] - 0s 383us/step - loss: 0.0134 - val_loss: 0.0141
Epoch 19/30
95/95 [==============================] - 0s 328us/step - loss: 0.0125 - val_loss: 0.0154
Epoch 20/30
95/95 [==============================] - 0s 401us/step - loss: 0.0130 - val_loss: 0.0144
Epoch 21/30
95/95 [==============================] - 0s 338us/step - loss: 0.0124 - val_loss: 0.0158
Epoch 22/30
95/95 [==============================] - 0s 359us/step - loss: 0.0131 - val_loss: 0.0148
Epoch 23/30
95/95 [==============================] - 0s 338us/step - loss: 0.0126 - val_loss: 0.0164
Epoch 24/30
95/95 [==============================] - 0s 380us/step - loss: 0.0135 - val_loss: 0.0150
Epoch 25/30
95/95 [==============================] - 0s 378us/step - loss: 0.0127 - val_loss: 0.0167
Epoch 26/30
95/95 [==============================] - 0s 541us/step - loss: 0.0137 - val_loss: 0.0151
Epoch 27/30
95/95 [==============================] - 0s 528us/step - loss: 0.0127 - val_loss: 0.0166
Epoch 28/30
95/95 [==============================] - 0s 423us/step - loss: 0.0134 - val_loss: 0.0150
Epoch 29/30
95/95 [==============================] - 0s 515us/step - loss: 0.0125 - val_loss: 0.0164
Epoch 30/30
95/95 [==============================] - 0s 457us/step - loss: 0.0131 - val_loss: 0.0150
[0.6991743 0.4155811 0.43763575 0.1943914 0.24489456 0.43544254
0.728908 0.27704275 0.7644203 0.24740852 0.58411294 0.33986062
0.28997922 0.13274276 0.74714196 0.5237809 0.36774576 0.5282971
0.23951268 0.6239692 0.15398878 0.4958876 0.10568523 0.55706674
0.32880494 0.60746497 0.294434 ]
[[1. ]
[0.25675676]
[0.4034749 ]
[0.11969112]
[0.17374517]
[0.58108108]
[0.4980695 ]
[0.25675676]
[0.55405405]
[0.17760618]
[0.5 ]
[0.31853282]
[0.2992278 ]
[0.01930502]
[0.58108108]
[0.48648649]
[0.4015444 ]
[0.38030888]
[0.13127413]
[0.61003861]
[0.18339768]
[0.38996139]
[0.12741313]
[0.63899614]
[0.40733591]
[0.87837838]
[0.20656371]]
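The two arrays above are the raw scaled predictions and ground-truth values; eyeballing them is awkward, so a single error number helps. A small sketch (with placeholder arrays, not the actual run's output) of scoring the fit by RMSE once both series are back on the original passenger scale:

```python
import numpy as np

# Hypothetical inverse-transformed values, standing in for predict_y / test_y.
predict_y = np.array([420.0, 460.0, 390.0])
test_y    = np.array([417.0, 461.0, 391.0])
# Root-mean-square error on the original passenger scale.
rmse = np.sqrt(np.mean((predict_y - test_y) ** 2))
print(rmse)   # ~1.915 passengers for these toy values
```

Computing RMSE (or MAE) on the inverse-transformed values rather than the scaled ones keeps the metric in interpretable units, i.e. passengers per month.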