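1. MNIST is a dataset of handwritten digit images. It contains 60,000 training images and 10,000 test images. Each image is a 28 × 28 grayscale image, and together they cover the 10 digits 0 through 9.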
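import keras
from keras.datasets import mnist
# load_data() returns two (images, labels) tuples of NumPy arrays; the files are downloaded on first use
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()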
2. Basic information about the dataset
print('shape of train images is ',train_images.shape)
print('shape of train labels is ',train_labels.shape)
print('train labels is ',train_labels)
print('shape of test images is ',test_images.shape)
print('shape of test labels is',test_labels.shape)
print('test labels is',test_labels)
shape of train images is (60000, 28, 28)
shape of train labels is (60000,)
train labels is [5 0 4 ... 5 6 8]
shape of test images is (10000, 28, 28)
shape of test labels is (10000,)
test labels is [7 2 1 ... 4 5 6]
3. Design the network architecture (in Keras, layers are the building blocks of a network)
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512,activation='relu',input_shape=(28*28,)))
network.add(layers.Dense(10,activation='softmax'))
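To get a feel for the size of this two-layer network, the number of trainable parameters can be worked out by hand. The sketch below assumes the standard fully connected layer layout (one weight per input/unit pair plus one bias per unit); `dense_params` is a hypothetical helper written for illustration, not a Keras function.

```python
# Parameter count for a fully connected (Dense) layer:
# one weight per (input, unit) pair plus one bias per unit.
def dense_params(n_inputs: int, n_units: int) -> int:
    return n_inputs * n_units + n_units

hidden = dense_params(28 * 28, 512)  # first Dense layer: 784 inputs, 512 units
output = dense_params(512, 10)       # softmax layer: 512 inputs, 10 units
total = hidden + output

print(hidden, output, total)  # 401920 5130 407050
```

Calling `network.summary()` on the model above should report the same per-layer counts.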
4. Specify the optimizer, loss function, and metrics, then compile the network
network.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
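The `categorical_crossentropy` loss compares a one-hot label with the probability vector produced by the softmax layer. The following is a minimal NumPy sketch of that formula, not the Keras implementation; the function name, the `eps` clipping value, and the three-class example data are assumptions chosen for illustration.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample cross-entropy between one-hot labels and predicted probabilities."""
    # Clip to avoid log(0); eps is an illustrative choice.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

# One sample whose true class is index 2, predicted with probability 0.7.
y_true = np.array([[0.0, 0.0, 1.0]])
y_pred = np.array([[0.1, 0.2, 0.7]])

loss = categorical_crossentropy(y_true, y_pred)
print(loss)  # equals -log(0.7), roughly 0.357
```

Only the probability assigned to the correct class contributes to the loss, so the loss approaches 0 as that probability approaches 1.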
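5. Prepare the data
The raw images are stored in a NumPy array of shape (60000, 28, 28) whose elements are integers in [0, 255]. We reshape them to (60000, 28*28) and convert the elements to floats in the interval [0, 1]. The labels are then one-hot encoded with to_categorical, because the categorical_crossentropy loss expects one-hot targets.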
train_images = train_images.reshape((60000,28*28))
train_images = train_images.astype('float32')/255
test_images = test_images.reshape((10000,28*28))
test_images = test_images.astype('float32')/255
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
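As an illustration of what the one-hot encoding step produces, here is a small NumPy sketch. `one_hot` is a hypothetical stand-in written for this example, not the Keras `to_categorical` function, and the three labels are the first three training labels printed earlier.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Turn integer class labels into rows with a single 1.0 at the label index."""
    encoded = np.zeros((len(labels), num_classes), dtype="float32")
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded

labels = np.array([5, 0, 4])
encoded = one_hot(labels, 10)

print(encoded.shape)  # (3, 10)
print(encoded[0])     # 1.0 at index 5, 0.0 elsewhere
```

After this step each label has the same length as the 10-unit softmax output, so the two can be compared element by element.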
6. Train the model
network.fit(train_images,train_labels,epochs=5,batch_size=128)
Epoch 1/5
60000/60000 [==============================] - 4s 58us/step - loss: 0.2554 - acc: 0.9266
Epoch 2/5
60000/60000 [==============================] - 4s 61us/step - loss: 0.1018 - acc: 0.9700
Epoch 3/5
60000/60000 [==============================] - 4s 72us/step - loss: 0.0665 - acc: 0.9806
Epoch 4/5
60000/60000 [==============================] - 5s 76us/step - loss: 0.0486 - acc: 0.9850
Epoch 5/5
60000/60000 [==============================] - 5s 80us/step - loss: 0.0359 - acc: 0.9889
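The `epochs` and `batch_size` arguments determine how many weight updates the run performs. The arithmetic below is a rough sketch that assumes one update per mini-batch and that the final, smaller batch of each epoch is still used.

```python
import math

num_samples = 60000
batch_size = 128
epochs = 5

# Batches per epoch, counting the final partial batch.
steps_per_epoch = math.ceil(num_samples / batch_size)
# Size of that final batch.
last_batch = num_samples - (steps_per_epoch - 1) * batch_size
# Total number of weight updates over the whole run.
total_updates = steps_per_epoch * epochs

print(steps_per_epoch, last_batch, total_updates)  # 469 96 2345
```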
7. Evaluate on the test set
test_loss,test_acc = network.evaluate(test_images,test_labels)
print("test_loss:",test_loss)
print("test_acc:",test_acc)
10000/10000 [==============================] - 0s 38us/step
test_loss: 0.06354748234406579
test_acc: 0.9804
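The accuracy metric reported here is the fraction of images whose highest-probability class matches the true label. The sketch below shows that computation on a made-up three-class example and then compares the training and test accuracies reported above; the `accuracy` helper and the sample arrays are illustrative assumptions, not Keras internals.

```python
import numpy as np

def accuracy(y_true_one_hot, y_pred_probs):
    """Fraction of samples where the predicted class equals the true class."""
    true_classes = np.argmax(y_true_one_hot, axis=1)
    pred_classes = np.argmax(y_pred_probs, axis=1)
    return float(np.mean(true_classes == pred_classes))

# Four made-up samples; the third one is misclassified.
y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
y_pred = np.array([[0.1, 0.2, 0.7],
                   [0.6, 0.3, 0.1],
                   [0.5, 0.4, 0.1],
                   [0.2, 0.2, 0.6]])
acc = accuracy(y_true, y_pred)
print(acc)  # 0.75

# Accuracies copied from the training and evaluation logs above.
reported_train_acc = 0.9889
reported_test_acc = 0.9804
gap = reported_train_acc - reported_test_acc
print(round(gap, 4))  # 0.0085
```

The test accuracy is slightly lower than the final training accuracy, which is the pattern one would expect if the model has begun to fit the training set a little more closely than it generalizes.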