Let's look at a few examples to get a feel for how convenient Keras is. There is no need to study the code in detail; just follow the implementation process. Keras modularizes each component in a composable, building-block style, so you can assemble the pieces however you like. We start with a Keras implementation of handwritten-digit recognition on MNIST; the rest is simply a matter of walking through the steps.
Implementing MNIST recognition with Keras:
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
model = Sequential()
model.add(Dense(500, input_shape=(784,), init='glorot_uniform'))  # input layer: 28*28 = 784
model.add(Activation('tanh'))  # tanh activation
model.add(Dropout(0.5))  # 50% dropout
model.add(Dense(500, init='glorot_uniform'))  # hidden layer with 500 nodes
model.add(Activation('tanh'))
model.add(Dropout(0.5))
# the output covers 10 classes, so the output dimension is 10
model.add(Dense(10))
model.add(Activation('softmax'))  # softmax on the output layer
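# softmax turns the 10 raw outputs z into a probability distribution:
#   p_k = exp(z_k) / sum_j exp(z_j)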
# set the learning rate (lr) and the other SGD parameters
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
# use cross-entropy as the loss function, i.e. the familiar log loss
model.compile(loss='categorical_crossentropy',
              optimizer=sgd, class_mode='categorical')
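# For reference, categorical cross-entropy for a single sample is
#   L = -sum_k y_k * log(p_k),
# where y is the one-hot label vector and p is the softmax output.
# An illustrative numpy equivalent (not part of the model; Y_true / Y_pred
# are hypothetical arrays of one-hot labels and predicted probabilities):
#   np.mean(-np.sum(Y_true * np.log(Y_pred + 1e-7), axis=1))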
# Alternatively, download the data file in advance and load it locally.
# np.load cannot read from a URL directly, so first fetch
# https://s3.amazonaws.com/img-datasets/mnist.npz and point path at the local copy.
path = "mnist.npz"  # local path to the downloaded file
f = np.load(path)
X_train = f['x_train']
Y_train = f['y_train']
X_test = f['x_test']
Y_test = f['y_test']
f.close()
# Or read the data with Keras's built-in mnist utility (needs a network connection the first time):
#(X_train, y_train), (X_test, y_test) = mnist.load_data()
# The input data has shape (num, 28, 28), so flatten the last two dimensions into 784
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
X_test = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])
# Convert the integer class indices into one-hot matrices
Y_train = (np.arange(10) == Y_train[:,None]).astype(int)
Y_test = (np.arange(10) == Y_test[:,None]).astype(int)
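# Keras also ships a helper for this index-to-one-hot conversion; applied to
# the original integer labels (in place of the two lines above), it would read:
#   from keras.utils import np_utils
#   Y_train = np_utils.to_categorical(Y_train, 10)
#   Y_test = np_utils.to_categorical(Y_test, 10)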
# Start training. There are quite a few parameters here.
# batch_size is the mini-batch size; nb_epoch is the maximum number of epochs;
# shuffle controls whether the data is randomly shuffled before training.
# verbose is the logging mode; as the docs put it: 0 for no logging to stdout,
# 1 for progress bar logging, 2 for one log line per epoch.
# show_accuracy displays the accuracy after each epoch.
# validation_split is the fraction of the data held out for validation;
# with 0.3 here, each epoch trains on 0.7 * 60000 = 42000 samples.
model.fit(X_train, Y_train, batch_size=200, nb_epoch=20, shuffle=True, verbose=1, validation_split=0.3)
print('test set')
model.evaluate(X_test, Y_test, batch_size=200, verbose=1)

Part of the training output (here from a run with 100 epochs) looks like this:
Epoch 86/100
42000/42000 [==============================] - 3s - loss: 0.3688 - val_loss: 0.2104
Epoch 87/100
42000/42000 [==============================] - 3s - loss: 0.3673 - val_loss: 0.2193
Epoch 88/100
42000/42000 [==============================] - 3s - loss: 0.3767 - val_loss: 0.2185
Epoch 89/100
42000/42000 [==============================] - 3s - loss: 0.3667 - val_loss: 0.2092
Epoch 90/100
42000/42000 [==============================] - 3s - loss: 0.3526 - val_loss: 0.2109
Epoch 91/100
42000/42000 [==============================] - 3s - loss: 0.3544 - val_loss: 0.2103
Epoch 92/100
42000/42000 [==============================] - 3s - loss: 0.3686 - val_loss: 0.2115
Epoch 93/100
42000/42000 [==============================] - 3s - loss: 0.3647 - val_loss: 0.2057
Epoch 94/100
42000/42000 [==============================] - 3s - loss: 0.3591 - val_loss: 0.2043
Epoch 95/100
42000/42000 [==============================] - 3s - loss: 0.3516 - val_loss: 0.2032
Epoch 96/100
42000/42000 [==============================] - 3s - loss: 0.3487 - val_loss: 0.2042
Epoch 97/100
42000/42000 [==============================] - 3s - loss: 0.3451 - val_loss: 0.2053
Epoch 98/100
42000/42000 [==============================] - 3s - loss: 0.3450 - val_loss: 0.2061
Epoch 99/100
42000/42000 [==============================] - 3s - loss: 0.3491 - val_loss: 0.2033
Epoch 100/100
42000/42000 [==============================] - 3s - loss: 0.3400 - val_loss: 0.2069
test set
8800/10000 [=========================>....] - ETA: 0s
Out[5]: 0.19544119497761131
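Here evaluate returns only the test loss (the 0.1954 above), since no accuracy metric was requested. A minimal sketch for also checking test accuracy, assuming the variables from the listing above and the predict_classes helper available on Sequential models in these Keras versions:

pred = model.predict_classes(X_test, batch_size=200, verbose=0)  # predicted class indices
true = Y_test.argmax(axis=1)  # recover integer labels from the one-hot matrix
print('test accuracy:', np.mean(pred == true))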