[LSTM Study Notes 4] How to Develop Vanilla LSTMs

I. Vanilla LSTM
This is the standard and simplest LSTM model, as defined in the original LSTM paper. It can be used to solve simple sequence problems; its structure is shown in the figure below:
[Figure: structure of the Vanilla LSTM]
1. Keras Implementation

model = Sequential()  
model.add(LSTM(..., input_shape=(...)))  
model.add(Dense(...))
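For the echo problem developed below, these placeholders get filled in with 25 memory units, an input shape of 5 time steps by 10 features, and a 10-unit softmax output layer, exactly as in the full listing later in this note:

model = Sequential()
model.add(LSTM(25, input_shape=(5, 10)))      # 5 time steps, 10 one-hot features per step
model.add(Dense(10, activation='softmax'))    # one output probability per possible value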

2. Example: Echo Sequence Prediction Problem
Given an input sequence, predict the value observed at a given time step, as illustrated below:

[Figure: example of the echo sequence prediction task]
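As a minimal sketch of the task (the sequence values here are borrowed from the prediction example at the end of this note), the model must simply echo the value observed at a chosen time step:

sequence = [1, 6, 4, 4, 8]        # random input sequence of length 5
out_index = 2                     # the time step to echo
expected = sequence[out_index]    # -> 4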

Code implementation:

from random import randint
from numpy import array
from numpy import argmax
import keras
from keras.layers import Dense, LSTM
from keras.models import Sequential


# generate a sequence of random integers in [0, n_features)
def generate_sequence(length, n_features):
    return [randint(0,n_features-1) for _ in range(length)]

# one hot encode sequence
def one_hot_encode(sequence, n_features):
    encoding = list()
    for value in sequence:
        vector = [0 for _ in range(n_features)]
        vector[value] = 1
        encoding.append(vector)
    return array(encoding)

# decode a one hot encoded sequence
def one_hot_decode(encoded_seq):
    return [argmax(vector) for vector in encoded_seq]
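# Example round trip with illustrative values:
#   one_hot_encode([3, 0], n_features=4) -> [[0, 0, 0, 1], [1, 0, 0, 0]]
#   one_hot_decode([[0, 0, 0, 1], [1, 0, 0, 0]]) -> [3, 0]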

# generate one example for an lstm
def generate_example(length, n_features, out_index):
    # generate sequence
    sequence = generate_sequence(length, n_features)
    # one hot encode
    encoded = one_hot_encode(sequence, n_features)
    # reshape to the 3D format expected by the LSTM: [samples, time steps, features]
    X = encoded.reshape((1, length, n_features))
    # select output
    y = encoded[out_index].reshape(1, n_features)
    return X, y

# define model
length = 5
n_features = 10
out_index = 2
model = Sequential()
# 25 memory units in the LSTM layer; fully connected output layer with n_features (10) softmax neurons
model.add(LSTM(25, input_shape=(length, n_features)))
model.add(Dense(n_features, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
print(model.summary())

# fit the model: one freshly generated random example per update, 10,000 updates
for i in range(10000):
    X, y = generate_example(length, n_features, out_index)
    model.fit(X, y, epochs=1, verbose=2)

# evaluate model
correct = 0
for i in range(100):
    X, y = generate_example(length, n_features, out_index)
    yhat = model.predict(X)
    if one_hot_decode(yhat) == one_hot_decode(y):
        correct += 1
print('Accuracy:%f' % ((correct / 100) * 100.0))

The output is:


Accuracy:100.000000
# prediction on new data
X, y = generate_example(length=5, n_features=10, out_index=2)
yhat = model.predict(X)
print('Sequence: %s' % [one_hot_decode(x) for x in X])
print('Expected: %s' % one_hot_decode(y))
print('Predicted : %s' % one_hot_decode(yhat))
Sequence: [[1, 6, 4, 4, 8]]
Expected: [4]
Predicted : [4]
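As a closing aside, the training loop above calls fit() once per freshly generated example. A minimal alternative sketch (not part of the original note; numpy's vstack and the variable names below are my own choices) pre-generates the examples and trains on them as a single dataset:

from numpy import vstack

# hypothetical variant: stack many generated examples, then call fit() once
Xs, ys = list(), list()
for _ in range(10000):
    Xi, yi = generate_example(length, n_features, out_index)
    Xs.append(Xi)
    ys.append(yi)
X_all = vstack(Xs)   # shape: (10000, length, n_features)
y_all = vstack(ys)   # shape: (10000, n_features)
model.fit(X_all, y_all, epochs=5, batch_size=32, verbose=2)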
