This article shows how to classify the IMDB movie-review dataset with an RNN and an LSTM using Keras.
Example code:
from keras.models import Sequential
from keras.layers import Embedding, SimpleRNN
# Baseline: SimpleRNN returns only the output of the last timestep.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32))
model.summary()
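With this configuration, model.summary() reports 320,000 parameters for the Embedding layer (10,000 words × 32 dimensions) and 2,080 for the SimpleRNN layer (32 × 32 input weights + 32 × 32 recurrent weights + 32 biases).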
# With return_sequences=True the layer outputs its state at every timestep.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
model.summary()
# Stacked RNNs: every intermediate layer must return the full sequence.
model = Sequential()
model.add(Embedding(10000, 32))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32, return_sequences=True))
model.add(SimpleRNN(32))  # last layer only returns its final output
model.summary()
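To make the effect of return_sequences concrete, here is a quick sketch (the dummy input and the name probe are illustrative, not part of the original code):
import numpy as np

# With return_sequences=True the RNN emits one 32-dim vector per
# timestep rather than only the final state.
probe = Sequential()
probe.add(Embedding(10000, 32))
probe.add(SimpleRNN(32, return_sequences=True))
dummy = np.random.randint(0, 10000, size=(1, 500))
print(probe.predict(dummy).shape)  # (1, 500, 32): one output per timestep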
from keras.datasets import imdb
from keras.preprocessing import sequence

max_features = 10000  # number of words to keep as features
maxlen = 500          # cut reviews off after this many words
batch_size = 32

print('Loading data...')
(input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
print(len(input_train), 'train sequences')
print(len(input_test), 'test sequences')

print('Pad sequences (samples x time)')
input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
print('input_train shape:', input_train.shape)
print('input_test shape:', input_test.shape)
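As an optional sanity check (not part of the original listing), a padded sequence can be decoded back to words via imdb.get_word_index(); the offset of 3 skips the reserved padding/start/unknown indices:
# Decode the first (padded) training review back to text.
word_index = imdb.get_word_index()
reverse_word_index = {index: word for word, index in word_index.items()}
decoded = ' '.join(reverse_word_index.get(i - 3, '?') for i in input_train[0])
print(decoded[-300:])  # the review sits at the end after pre-padding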
from keras.layers import Dense

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(SimpleRNN(32))
model.add(Dense(1, activation='sigmoid'))  # binary sentiment output
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
hist = model.fit(input_train, y_train,
                 epochs=10,
                 batch_size=128,
                 validation_split=0.2)
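A natural follow-up (not shown in the original code) is to score the trained model on the held-out test set:
# Evaluate generalization on the 25,000 test reviews.
test_loss, test_acc = model.evaluate(input_test, y_test, batch_size=128)
print('SimpleRNN test accuracy:', test_acc)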
import matplotlib.pyplot as plt

acc = hist.history['acc']
val_acc = hist.history['val_acc']
loss = hist.history['loss']
val_loss = hist.history['val_loss']
epochs = range(1, len(acc) + 1)  # number epochs from 1, matching epochs=10

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
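An LSTM handles long-range dependencies better than a SimpleRNN, whose gradients tend to vanish over long sequences; swapping the recurrent layer is a one-line change: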
from keras.layers import LSTM

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(LSTM(32))  # replace the SimpleRNN with an LSTM layer
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['acc'])
hist = model.fit(input_train, y_train,
                 epochs=10,
                 batch_size=128,
                 validation_split=0.2)
acc = hist.history['acc']
val_acc = hist.history['val_acc']
loss = hist.history['loss']
val_loss = hist.history['val_loss']
epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
Training output for the final epoch:
... (per-batch progress lines omitted) ...
20000/20000 [==============================] - 29s 1ms/step - loss: 0.0194 - acc: 0.9952 - val_loss: 0.6177 - val_acc: 0.8292
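Note the gap in the final epoch: training accuracy reaches 99.52% while validation accuracy is only 82.92% (loss 0.0194 vs. val_loss 0.6177), so the model has clearly overfit by epoch 10. Stopping earlier, or adding dropout/recurrent_dropout to the LSTM layer, would likely narrow this gap.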