基于keras的mnist手写体识别程序

大家好 我是来自河北大学 心电组的一名研一的学生,本篇文章是我对mnist识别学习的认识和分享。

本文主要用来给想要用keras搭建网络识别mnist的同学一个引导。

有错误的地方请大家指正

我会虚心接受

首先是库的安装,我选择的版本是tensorflow-gpu 2.6.0版本,大家如果和我的版本一样可以直接复制上面的代码进行库的导入。

from tensorflow.keras.datasets import mnist
from tensorflow.keras import models,layers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import optimizers
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

首先我们先看第一步:数据集导入和处理

#下载MNIST数据集
(train_datas,train_labels),(test_datas,test_labels)=mnist.load_data()

#将数据集进行预处理和归一化
train_datas=train_datas.reshape((60000,28,28,1))
train_datas=train_datas.astype('float32')/255
test_datas=test_datas.reshape((10000,28,28,1))
test_datas=test_datas.astype('float32')/255

#对标签进行热编码
train_labels=to_categorical(train_labels)
test_labels=to_categorical(test_labels)

我们首先通过keras库中自带的MNIST数据集进行数据的下载,数据已经是帮我们分好训练集和测试集的,训练集有60000张图片,测试集有10000张图片,全部都是28×28像素的单通道灰度图。

之后我们要对训练集和测试集的图片进行归一化处理,那为什么要归一化呢?

图片的数值在0~255之间,我们之所以让图片除以255是为了让图片的数值全部变成0~1,这样在我们后续网络中会更好的让分类器进行分类。

接下来就是要对数据集的标签进行热编码,这一步基本是必须做的,这样使特征之间的距离计算更加合理,理解不了的同学可以直接把这步操作记住

第二步:验证集的划分

#K折交叉验证
k=4
num_val_samples=len(train_datas)//k
for i in range(k):
    print('processing fold #',i)
    val_train_datas=train_datas[i*num_val_samples:(i+1)*num_val_samples]
    val_train_labels=train_labels[i*num_val_samples:(i+1)*num_val_samples]

    partial_train_datas=np.concatenate(
        [train_datas[:i*num_val_samples],
         train_datas[(i+1)*num_val_samples:]],
        axis=0
    )
    partial_train_labels=np.concatenate(
        [train_labels[:i*num_val_samples],
         train_labels[(i+1)*num_val_samples:]],
        axis=0
    )

这里我们使用K折交叉验证来将训练集划分出验证集,K折交叉验证的原理大家可以参考其他资料来进行学习

第三步:搭建网络模型

#定义网络模型
def network():
    model=models.Sequential()
    model.add(layers.Conv2D(16,(3,3),activation='relu',input_shape=(28,28,1)))
    model.add(layers.Conv2D(32,(3,3),activation='relu',padding='same'))
    model.add(layers.MaxPooling2D(2,2))
    model.add(layers.Dropout(0.5))
    model.add(layers.Conv2D(32,(3,3),activation='relu',padding='same'))
    model.add(layers.Conv2D(64,(3,3),activation='relu',padding='same'))
    model.add(layers.MaxPooling2D(2,2))
    model.add(layers.Dropout(0.5))
    model.add(layers.Conv2D(64,(3,3),activation='relu',padding='same'))
    model.add(layers.MaxPooling2D(2,2))
    model.add(layers.Dropout(0.5))
    model.add(layers.Flatten())
    model.add(layers.Dense(512,activation='relu'))
    model.add(layers.Dense(10,activation='softmax'))
    model.compile(optimizer=optimizers.RMSprop(learning_rate=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

这个模型是我基于CNN卷积神经网络的模型自己逐步改进的,大家不要觉得搭建网络很难,我们一般搭建网络都是先从很小的网络开始搭建,之后再不断的调整网络的超参数和加入网络其他层来使网络性能达到最优。

首先利用keras搭建模型第一步就是要先加入Sequential(),这一步的作用就是定义一个网络的框架,让我们有一个容器可以往里面放其他层。

接下来我们就是向这个框架中加入卷积层和池化层,Conv2D卷积层的各个参数大家不知道的可以参考其他资料自己解决,在网络的第一层我们要设置输入的形状,因为我们MNIST数据集的图片是28×28像素单通道的图片,所以这里input_shape我们设置为(28,28,1)

激活函数我们选取ReLU函数

为了防止模型过拟合,提升网络模型的分类性能,所以我们向网络中加入Dropout层来舍去一部分参数,这里Dropout我们设置为0.5,让它按照50%的比例来进行舍去。

之后经过平铺层展平后送入密集层进行特征提取,提取到的特征送入softmax分类器进行分类。最后一层Dense中之所以把参数设为10,是因为我们MNIST中0-9一共有10个不同的类别,所以我们这里设置为10,如果是有5个类别就设置为5,二分类就设置为2,但是二分类的分类器多用Sigmoid来分类。

网络搭建好之后,我们想要让网络自己调节超参数来修正偏差,如何来做呢?

这就需要向网络中设置优化器和损失函数,其中工作的原理大家可以自己去哔哩哔哩或者其他网站查询。

优化器我们选择REMprop,学习率设置为0.001,也可以表示为1e-3,

损失函数选择categroical_crossentropy交叉熵损失函数

现在我们来打印一下网络模型来看一下,大家自己写的时候可以不需要打印这个

基于keras的mnist手写体识别程序_第1张图片

 第四步:开始训练

#构造网络进行训练
model=network()
history=model.fit(partial_train_datas,partial_train_labels,epochs=60
                  ,batch_size=512,
                  validation_data=(val_train_datas,val_train_labels))

好了我们现在开始训练网络,因为我们网络搭建用的是一个函数定义的,所以一开始我们首先要先调用这个函数,然后开始训练。 

我们用来训练的数据是被划分完验证集之后的训练集,我们迭代次数可以随意调,这里我们设置迭代次数为60次,这是我多次调整之后确定的最佳迭代次数,介于过拟合和欠拟合的临界点。

批次我们设置为512为一批,意思是每次送入网络中训练512个为一组,网络的验证集就是我们之前K折交叉验证划分出来的验证集。基于keras的mnist手写体识别程序_第2张图片

 

第五步:绘制图像

#绘制损失图像
history_dict=history.history
loss_values=history_dict['loss']
val_loss_values=history_dict['val_loss']
epochs=range(1,len(loss_values)+1)
plt.plot(epochs,loss_values,'bo',label='train loss')
plt.plot(epochs,val_loss_values,'b',label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

#绘制准确率图像
plt.figure()
history_dict=history.history
acc_values=history_dict['accuracy']
val_acc_values=history_dict['val_accuracy']
epochs=range(1,len(loss_values)+1)
plt.plot(epochs,acc_values,'ro',label='train acc')
plt.plot(epochs,val_acc_values,'r',label='Validation acc')
plt.title('Training and Validation acc')
plt.xlabel('Epochs')
plt.ylabel('acc')
plt.legend()
plt.show()

这里我们可以选择绘制Loss和Accuracy的变化曲线,方便我们更直观的观察我们网络的性能变化,大家可以选择绘制,也可以不选择绘制。具体每个参数什么意思大家需要靠自己去查阅学习。

基于keras的mnist手写体识别程序_第3张图片

基于keras的mnist手写体识别程序_第4张图片

 

 

通过图像我们可以看到在训练集和验证集上的Loss和Accuracy基本重合,证明我们的模型临界点找的很好。

我们把测试集输入到网络中做最终的测试,来看一看结果如何

#在测试集上验证性能
test_loss,test_acc=model.evaluate(test_datas,test_labels)
print('test_acc:{}',format(test_acc))
print('test_loss:{}'.format((test_loss)))

在测试集上的准确率非常高,说明我们这个网络搭建的很成功。大家需要不断的对网络层数增减以及调节超参数来寻找最优的模型,这个需要慢慢来。

训练完网络之后,我们来自己输入一张手写图片来看看识别的效果如何

基于keras的mnist手写体识别程序_第5张图片

 这是我自己在电脑上用画图软件手写的一个图,注意我们这个图片和MNIST的图的颜色正好相反,MNIST中的图片是黑底白字,我们平时的都是白底黑字,所以待会我们需要一段代码来把颜色调节一下。强调一下,图像中0表示白色,255表示黑色

#导入图片
img_path=r'D:\python解释器\pycharm环境\pythonProject3\data\MNIST\com-5.png'
image=Image.open(img_path)
#将图片变为28×28像素
image=image.resize((28,28),Image.ANTIALIAS)
#将图片转为单通道灰度图
image=image.convert('L')
#将灰度图转为numpy数组
image_rr=np.array(image.convert('L'))
#循环遍历,将图片中白色变成黑色,黑色变成白色
for i in range(28):
    for j in range(28):
        if image_rr[i][j]<100:
            image_rr[i][j]=255
        else:
            image_rr[i][j]=0
#再次对图像进行归一化处理
image_rr=image_rr/255
#将图像转为适合网络输入的形式
image_rr=np.reshape(image_rr,(1,28,28,1))
#对图片放入网络进行预测
result=model.predict(image_rr)
#将结果横向排列,取最大值
pred=np.argmax(result,axis=1)

print('您手写的数字是:{}'.format(pred))

好的我们来看一看识别效果

 到这里本文的全部内容就结束啦,有哪里有纰漏的欢迎大家多多指正。

另:本文的网络只能识别在电脑或者手机画板上写的数字,在纸上写的拍照传入网络中会识别不准,初步分析认为是因为MNIST数据集中并没有在纸上写的手写体,所以网络有些特征提取不到,导致识别错误。

谢谢大家点击支持。

你可能感兴趣的:(keras,tensorflow,深度学习)