The MNIST dataset is the classic handwritten digit recognition dataset.
It is available at http://yann.lecun.com/exdb/mnist/ and consists of four parts:
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB uncompressed, 60,000 samples)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB uncompressed, 60,000 labels)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB uncompressed, 10,000 samples)
Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB uncompressed, 10,000 labels)
MNIST comes from the United States National Institute of Standards and Technology (NIST). The training set consists of digits handwritten by 250 different people, 50% of them high school students and 50% staff of the Census Bureau. The test set contains handwritten digits in the same proportions.
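The four files use NIST's IDX binary format rather than ordinary image files. If you download the .gz archives yourself, a minimal parser for the image files might look like the sketch below (`read_idx3_images` is a hypothetical helper name, not part of any library):

```python
import gzip
import struct

import numpy as np

def read_idx3_images(path):
    """Parse an idx3-ubyte image file such as train-images-idx3-ubyte.gz.

    Layout: a 4-byte magic number (2051 for image files), then the sample
    count, row count, and column count as big-endian 32-bit integers,
    followed by the raw uint8 pixel data.
    """
    with gzip.open(path, 'rb') as f:
        magic, n, rows, cols = struct.unpack('>IIII', f.read(16))
        assert magic == 2051, 'not an idx3-ubyte image file'
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    return pixels.reshape(n, rows, cols)
```

The label files work the same way, except the magic number is 2049 and the header has no row/column fields.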
There are many ways to load the dataset; here are a few.
Method 1:
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
We can print the shapes:
print('Training images shape:', x_train.shape)
print('Training labels shape:', y_train.shape)
print('Test images shape:', x_test.shape)
print('Test labels shape:', y_test.shape)
which gives:
Training images shape: (60000, 28, 28)
Training labels shape: (60000,)
Test images shape: (10000, 28, 28)
Test labels shape: (10000,)
Method 2: the data is not stored as ordinary image files (such as .jpg). The official site serves binary IDX files, and tf.keras caches the dataset locally as a single .npz archive, so to actually look at a digit we have to plot the raw array. A simple helper does the job:
import matplotlib.pyplot as plt

def mnist_visualize_single(mode, idx):
    # mode 0 selects a training image, anything else a test image
    if mode == 0:
        plt.imshow(x_train[idx], cmap=plt.get_cmap('gray'))
        title = 'label=' + str(y_train[idx])
    else:
        plt.imshow(x_test[idx], cmap=plt.get_cmap('gray'))
        title = 'label=' + str(y_test[idx])
    plt.title(title)
    plt.xticks([])
    plt.yticks([])
    plt.show()

mnist_visualize_single(mode=0, idx=0)
As the figure above shows, the label looks odd only because I had already converted the labels to one-hot encoding.
Preprocessing:
The images must first be reshaped into a four-dimensional tensor for network training, and cast from uint8 to float32 so the subsequent normalization keeps fractional precision.
# x_val is assumed to have been split off from the training set earlier,
# and x_test_original keeps an untouched copy of the test images for display.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_val = x_val.reshape(x_val.shape[0], 28, 28, 1).astype('float32')
x_test = x_test_original.reshape(x_test_original.shape[0], 28, 28, 1).astype('float32')
The raw pixel values are grayscale intensities in the range 0-255. To improve training accuracy and save training time, they are usually normalized to 0-1:
x_train = x_train / 255
x_val = x_val / 255
x_test = x_test / 255
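The code above uses x_val and y_val without showing where they come from. Here is a minimal sketch of that missing step, assuming the validation set is the last 10,000 training samples (consistent with the 1563 steps per epoch at batch size 32 in the training log below, i.e. 50,000 training samples), together with the one-hot encoding that categorical_crossentropy expects; `split_and_encode` is a hypothetical helper name:

```python
import numpy as np
from tensorflow.keras.utils import to_categorical

def split_and_encode(x, y, n_val=10000, num_classes=10):
    """Split off a validation set from the end of the training data
    and one-hot encode the integer labels."""
    x_train, x_val = x[:-n_val], x[-n_val:]
    y_train = to_categorical(y[:-n_val], num_classes)
    y_val = to_categorical(y[-n_val:], num_classes)
    return x_train, y_train, x_val, y_val
```

For example, x_train, y_train, x_val, y_val = split_and_encode(x_train, y_train) would produce the four arrays the later training code assumes.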
2: Building the network model
The model is a convolutional neural network (CNN) assembled with the Sequential API. Following the usual six-step recipe for building a network (and the CBAPD layer pattern), we can define a function:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def CNN_model():
    model = Sequential()
    model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(Flatten())
    model.add(Dense(100, activation='relu'))
    model.add(Dense(10, activation='softmax'))
    return model

model = CNN_model()
Of course the same model could also be built as a subclassed Model with the layers defined in __init__; I will cover that later.
At this point we can print the network structure to inspect it (note that model.summary() prints and returns None, so wrapping it in print() adds a stray "None" line):
model.summary()
The result (layer shapes and parameter counts follow from the model defined above) is:
Model: "sequential"
_________________________________________________________________
 Layer (type)                   Output Shape              Param #
=================================================================
 conv2d (Conv2D)                (None, 24, 24, 16)        416
 max_pooling2d (MaxPooling2D)   (None, 12, 12, 16)        0
 conv2d_1 (Conv2D)              (None, 8, 8, 32)          12832
 max_pooling2d_1 (MaxPooling2D) (None, 4, 4, 32)          0
 flatten (Flatten)              (None, 512)               0
 dense (Dense)                  (None, 100)               51300
 dense_1 (Dense)                (None, 10)                1010
=================================================================
Total params: 65,558
Trainable params: 65,558
Non-trainable params: 0
_________________________________________________________________
3: Compiling and training the network
Following the standard Keras recipe, model.compile() configures the network: we choose a loss function, an optimizer, and the metrics to track. (Exactly why these particular choices work best is something I have not fully worked out yet; I will update this later.)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
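As an aside, the one-hot conversion can be avoided entirely: if the labels are kept as integers 0-9, Keras's sparse_categorical_crossentropy computes the same loss directly. A minimal sketch (the tiny model here is just a stand-in, not the CNN above):

```python
import tensorflow as tf

# With integer labels, sparse_categorical_crossentropy replaces
# categorical_crossentropy and no one-hot encoding step is needed.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
```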
3.1 Training the network
Keras offers several ways to train a network; here we use model.fit(). Its arguments include the training data and labels, the validation data and labels, the number of epochs, and the batch size:
train_history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, batch_size=32, verbose=2)
The training output looks like this:
Epoch 1/20
1563/1563 - 16s - loss: 0.1495 - accuracy: 0.9540 - val_loss: 0.0663 - val_accuracy: 0.9812 - 16s/epoch - 11ms/step
Epoch 2/20
1563/1563 - 14s - loss: 0.0472 - accuracy: 0.9855 - val_loss: 0.0567 - val_accuracy: 0.9846 - 14s/epoch - 9ms/step
Epoch 3/20
1563/1563 - 14s - loss: 0.0315 - accuracy: 0.9898 - val_loss: 0.0370 - val_accuracy: 0.9898 - 14s/epoch - 9ms/step
Epoch 4/20
1563/1563 - 14s - loss: 0.0231 - accuracy: 0.9926 - val_loss: 0.0502 - val_accuracy: 0.9875 - 14s/epoch - 9ms/step
Epoch 5/20
1563/1563 - 14s - loss: 0.0172 - accuracy: 0.9944 - val_loss: 0.0425 - val_accuracy: 0.9901 - 14s/epoch - 9ms/step
Epoch 6/20
1563/1563 - 14s - loss: 0.0138 - accuracy: 0.9954 - val_loss: 0.0437 - val_accuracy: 0.9894 - 14s/epoch - 9ms/step
Epoch 7/20
1563/1563 - 14s - loss: 0.0116 - accuracy: 0.9961 - val_loss: 0.0463 - val_accuracy: 0.9894 - 14s/epoch - 9ms/step
Epoch 8/20
1563/1563 - 14s - loss: 0.0084 - accuracy: 0.9974 - val_loss: 0.0498 - val_accuracy: 0.9884 - 14s/epoch - 9ms/step
Epoch 9/20
1563/1563 - 14s - loss: 0.0080 - accuracy: 0.9971 - val_loss: 0.0420 - val_accuracy: 0.9907 - 14s/epoch - 9ms/step
Epoch 10/20
1563/1563 - 14s - loss: 0.0076 - accuracy: 0.9977 - val_loss: 0.0752 - val_accuracy: 0.9859 - 14s/epoch - 9ms/step
Epoch 11/20
1563/1563 - 14s - loss: 0.0077 - accuracy: 0.9975 - val_loss: 0.0423 - val_accuracy: 0.9912 - 14s/epoch - 9ms/step
Epoch 12/20
1563/1563 - 14s - loss: 0.0048 - accuracy: 0.9984 - val_loss: 0.0608 - val_accuracy: 0.9888 - 14s/epoch - 9ms/step
Epoch 13/20
1563/1563 - 14s - loss: 0.0059 - accuracy: 0.9981 - val_loss: 0.0627 - val_accuracy: 0.9889 - 14s/epoch - 9ms/step
Epoch 14/20
1563/1563 - 14s - loss: 0.0050 - accuracy: 0.9985 - val_loss: 0.0635 - val_accuracy: 0.9899 - 14s/epoch - 9ms/step
Epoch 15/20
1563/1563 - 14s - loss: 0.0052 - accuracy: 0.9984 - val_loss: 0.0485 - val_accuracy: 0.9911 - 14s/epoch - 9ms/step
Epoch 16/20
1563/1563 - 14s - loss: 0.0032 - accuracy: 0.9990 - val_loss: 0.0773 - val_accuracy: 0.9886 - 14s/epoch - 9ms/step
Epoch 17/20
1563/1563 - 14s - loss: 0.0051 - accuracy: 0.9985 - val_loss: 0.0606 - val_accuracy: 0.9904 - 14s/epoch - 9ms/step
Epoch 18/20
1563/1563 - 14s - loss: 0.0045 - accuracy: 0.9985 - val_loss: 0.0626 - val_accuracy: 0.9904 - 14s/epoch - 9ms/step
Epoch 19/20
1563/1563 - 14s - loss: 0.0047 - accuracy: 0.9988 - val_loss: 0.0663 - val_accuracy: 0.9906 - 14s/epoch - 9ms/step
Epoch 20/20
1563/1563 - 14s - loss: 0.0030 - accuracy: 0.9991 - val_loss: 0.0851 - val_accuracy: 0.9889 - 14s/epoch - 9ms/step
It is easy to see that as training progresses the training loss keeps shrinking and validation accuracy stays near 99%, so the network is clearly learning. (Note that val_loss bottoms out after a few epochs and then drifts upward while the training loss keeps falling, a mild sign of overfitting.)
3.2 Visualizing the training process
A small helper makes the accuracy and loss curves easier to read:
def show_train_history(train_history, train, validation):
    plt.plot(train_history.history[train])
    plt.plot(train_history.history[validation])
    plt.title('Train History')
    plt.ylabel(train)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='best')
    plt.show()
show_train_history(train_history, 'accuracy', 'val_accuracy')
show_train_history(train_history, 'loss', 'val_loss')
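Once the curves look good, the trained model can be saved so later prediction runs do not need to retrain. A sketch using Keras's standard save/load API; the filename 'mnist_cnn.h5' is just an example, and the tiny model below is only a stand-in so the snippet is self-contained (in the article, `model` is the trained CNN):

```python
import tensorflow as tf

# Stand-in model; in the article this would be the CNN trained above.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(loss='categorical_crossentropy', optimizer='adam')

# Save architecture + weights + optimizer state, then reload.
model.save('mnist_cnn.h5')
restored = tf.keras.models.load_model('mnist_cnn.h5')
```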
4: Network prediction
The dataset has been split into training, validation, and test sets. model.evaluate() measures how the network performs on the held-out test set:
score = model.evaluate(x_test, y_test)
Print the test loss and accuracy:
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Note that before this step y_test must be one-hot encoded to match the categorical_crossentropy loss:
y_test = tf.squeeze(y_test)
y_test = tf.one_hot(y_test, depth=10)
The printed output is:
313/313 [==============================] - 1s 3ms/step - loss: 0.0653 - accuracy: 0.9896
Test loss : 0.06529523432254791
Test accuracy : 0.9896000027656555
model.predict() generates class probabilities for the test images; np.argmax turns each row into a class label:
import numpy as np

predictions = model.predict(x_test)
predictions = np.argmax(predictions, axis=1)
print('Predictions for the first 20 images:', predictions[:20])
Predictions for the first 20 images: [7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
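Comparing these argmax predictions against the original integer labels also makes it easy to pull out the images the network got wrong. A small NumPy sketch (`misclassified_indices` is a hypothetical helper, and y_test_original is assumed to hold the labels before one-hot encoding):

```python
import numpy as np

def misclassified_indices(predicted, true_labels):
    """Return the test-set indices where the predicted class
    disagrees with the ground-truth label."""
    predicted = np.asarray(predicted)
    true_labels = np.asarray(true_labels)
    return np.where(predicted != true_labels)[0]
```

For example, misclassified_indices(predictions, y_test_original)[:10] would list the first ten mistakes, which are usually the most interesting images to plot.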
We can also write a function to display the first twenty images, similar to Method 2 above:
def mnist_visualize_multiple(mode, start, end, length, width):
    # mode 0 selects training images, anything else test images
    if mode == 0:
        for i in range(start, end):
            plt.subplot(length, width, 1 + i)
            plt.imshow(x_train[i], cmap=plt.get_cmap('gray'))
            title = 'label=' + str(y_train[i])
            plt.title(title)
            plt.xticks([])
            plt.yticks([])
    else:
        for i in range(start, end):
            plt.subplot(length, width, 1 + i)
            plt.imshow(x_test[i], cmap=plt.get_cmap('gray'))
            title = 'label=' + str(y_test[i])
            plt.title(title)
            plt.xticks([])
            plt.yticks([])
    plt.show()
In the same way, the predictions can be visualized alongside the true labels:
def mnist_visualize_multiple_predict(start, end, length, width):
    # model.predict_classes was removed in recent TensorFlow releases;
    # taking the argmax of model.predict gives the same class index.
    for i in range(start, end):
        plt.subplot(length, width, 1 + i)
        plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))
        pred = np.argmax(model.predict(np.expand_dims(x_test[i], axis=0)), axis=1)[0]
        title = 'true=' + str(y_test_original[i]) + ', prediction=' + str(pred)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    plt.show()
Calling
mnist_visualize_multiple_predict(start=0, end=9, length=3, width=3)
displays nine of them (only nine, to keep the figure compact).
To be continued.