基于MNIST数据集的手写数字识别项目2

GZ格式的MNIST数据集手写数字识别项目

  • 资源文件
  • 一、MNIST数据集介绍
  • 二、开发步骤
    • 1.引入库
    • 2.加载训练集验证集测试集
    • 3.显示训练集验证集测试集
    • 4.显示部分数据集中的数字图片
    • 5.构建网络模型
    • 6.编译模型
    • 7.训练模型
    • 8.绘制训练和验证结果

资源文件

gz格式的MNIST数据集


一、MNIST数据集介绍

(1)数据集有60000张
(2)每张图片大小28*28
(3)颜色通道:1(灰度)
(4)像素取值范围[0,255],0代表黑色,255代表白色
(5)每张图片有一个标签:0-9

二、开发步骤

1.引入库

import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Conv2D,Input,LeakyReLU,Dense,Activation,Flatten,Dropout,MaxPool2D
from keras import models
from keras.optimizers import Adam,RMSprop
from tensorflow.examples.tutorials.mnist import input_data

2.加载训练集验证集测试集

mnist = input_data.read_data_sets('/BASICCNN/MNIST_Data_Gather/MNIST_data/',one_hot=True)
Train_images = mnist.train.images.reshape([mnist.train.num_examples,28,28,1])
Train_labels = mnist.train.labels
Val_images = mnist.validation.images.reshape([mnist.validation.num_examples,28,28,1])
Val_labels = mnist.validation.labels
Test_images = mnist.test.images.reshape([mnist.test.num_examples,28,28,1])
Test_labels = mnist.test.labels
print(Train_images.shape)
print(Train_labels.shape)
plt.show()

在这里插入图片描述

3.显示训练集验证集测试集

XTrain = []
YTrain = []
for i in range(10):
    x = i
    y = np.sum(Train_labels[:,i]== 1)
    XTrain.append(x)
    YTrain.append(y)
    plt.text(x,y,'%s' % y,horizontalalignment='center',fontsize=14)
plt.bar(XTrain,YTrain,width=0.8,color='orange') #柱状图参数设置
plt.tick_params(labelsize=14)
plt.xticks(XTrain)
plt.xlabel('Digits',fontsize=16)
plt.ylabel('Frequency',fontsize=16)
plt.title('Frequency in Train Data',fontsize=20)
plt.savefig('/BASICCNN/TrainImage/MNIST_traingz.png')
plt.show()

XVal = []
YVal = []
for i in range(10):
    x = i
    y = np.sum(Val_labels[:,i]== 1)
    XVal.append(x)
    YVal.append(y)
    plt.text(x,y,'%s' % y,horizontalalignment='center',fontsize=14)
plt.bar(XVal,YVal,width=0.8,color='red') #柱状图参数设置
plt.tick_params(labelsize=14)
plt.xticks(XVal)
plt.xlabel('Digits',fontsize=16)
plt.ylabel('Frequency',fontsize=16)
plt.title('Frequency in Val Data',fontsize=20)
plt.savefig('/BASICCNN/TrainImage/MNIST_valgz.png')
plt.show()

基于MNIST数据集的手写数字识别项目2_第1张图片
基于MNIST数据集的手写数字识别项目2_第2张图片

4.显示部分数据集中的数字图片

rows = 5
cols = 6
fig = plt.figure(figsize=(cols,rows))
for i in range(rows*cols):
    fig.add_subplot(rows,cols,i+1)  #图片添加到相应的位置
    img = mnist.train.images[i].reshape(28, 28)
    plt.imshow(img,cmap='PuOr')
    plt.axis('off')
    plt.title(str(Train_labels[i].argmax()),y=-0.25,color='blue') #显示对应标签
plt.savefig('/BASICCNN/TrainImage/MNIST_showgz.png')
plt.show()

基于MNIST数据集的手写数字识别项目2_第3张图片

5.构建网络模型

model = models.Sequential()
model.add(Conv2D(32,(3,3),padding='same',input_shape=(28,28,1)))
model.add(LeakyReLU())
model.add(Conv2D(32,(3,3),padding='same'))
model.add(LeakyReLU())
model.add(MaxPool2D(pool_size=(2,2)))

model.add(Conv2D(64,(3,3),padding='same',input_shape=(28,28,1)))
model.add(LeakyReLU())
model.add(Conv2D(64,(3,3),padding='same'))
model.add(LeakyReLU())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.25))

model.add(Flatten()) #维度拉平
model.add(Dense(128,activation='relu'))
model.add(Dense(64,activation='relu'))
model.add(Dense(10,activation='softmax'))  #sigmoid归一[0,1]

6.编译模型

lr = 0.001 #学习率
loss = 'categorical_crossentropy' #损失函数
model.compile(Adam(lr=lr),loss=loss,metrics=['accuracy'])
model.summary()

基于MNIST数据集的手写数字识别项目2_第4张图片

7.训练模型

epochs = 2
batch_size = 32
history = model.fit(Train_images,Train_labels,batch_size=batch_size,epochs=epochs,validation_data=(Val_images,Val_labels))
model.save('/BASICCNN/TrainModel_h5/MNISTTraingz.h5')

基于MNIST数据集的手写数字识别项目2_第5张图片
在这里插入图片描述

8.绘制训练和验证结果

fig = plt.figure(figsize=(20,7))
plt.plot(history.epoch,history.history['accuracy'],label='Train Accuracy')
plt.plot(history.epoch,history.history['val_accuracy'],label='Val Accuracy')
plt.title('Accuracy Curve',fontsize=18)
plt.xlabel('Epochs',fontsize=15)
plt.ylabel('Accuracy',fontsize=15)
plt.legend()
plt.savefig('/BASICCNN/TrainImage/MNISTTraingz_accuracy.png')
plt.show()

plt.plot(history.epoch,history.history['loss'],label='Train Loss')
plt.plot(history.epoch,history.history['val_loss'],label='Val Loss')
plt.title('Loss Curve',fontsize=18)
plt.xlabel('Epochs',fontsize=15)
plt.ylabel('Loss',fontsize=15)
plt.legend()
plt.savefig('/BASICCNN/TrainImage/MNISTValgz_loss.png')
plt.show()

基于MNIST数据集的手写数字识别项目2_第6张图片
基于MNIST数据集的手写数字识别项目2_第7张图片
测试及可视化部分参考:基于MNIST数据集的手写数字识别项目1

你可能感兴趣的:(机器学习项目实践之旅,计算机视觉,深度学习,MNIST数据集,手写数字识别)