keras入门系列(二)——实现卷积神经网络手写数字的分类识别

【参考视频网址:】https://space.bilibili.com/45151802/video
(老师讲的特别好,良心推荐)
【mnist数据集下载:】https://download.csdn.net/download/weixin_41874898/11434107
【github代码下载:】https://github.com/Seasea77/keras_small_project_19_07_26

文章目录

      • 1. 实现功能:使用keras实现卷积神经网络手写数字的分类识别
      • 2. 文件目录
      • 3. keras_mnist_cnn_train.py
      • 4. keras_mnist_cnn_predict.py
      • 5.识别结果

1. 实现功能:使用keras实现卷积神经网络手写数字的分类识别

2. 文件目录

keras入门系列(二)——实现卷积神经网络手写数字的分类识别_第1张图片

3. keras_mnist_cnn_train.py

功能:生成conv_model.h5文件

# coding:utf-8


import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils  # 统一处理numpy数据的工具
from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense, Dropout
# MaxPool2D和MaxPooling2D的区别??
from keras.optimizers import Adam
from keras.models import Sequential


nb_class = 10
nb_epoch = 2
batch_size = 1024  # 128,内存小改为64,内存大改大些,跑起来速度变化不大就行。
# 这个参数有时很有个性,需要调节试试,有时loss很快,改一下有时loss就会很小
# 调参还可以调节lr学习率.


"""1数据读取与处理"""
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

X_train = X_train.reshape(-1, 28, 28, 1)    # channel_last通道数在最后一个维度;-1看成未知
X_test = X_test.reshape(-1, 28, 28, 1)    # channel_last通道数在最后一个维度;-1看成未知

# # 使用下面两条语句查看X_train是否归一化处理了
# a_csv = X_train.reshape(-1, 784)
# np.savetxt("a.csv", a_csv, delimiter=",")  # 1个多G。
# 实际是没有转

X_train = X_train / 255.  # 必须是浮点数,不然很难收敛。
Y_train = np_utils.to_categorical(Y_train, nb_class)
Y_test = np_utils.to_categorical(Y_test, nb_class)


"""2网络搭建(卷积网络+全连接网络)"""
model = Sequential()

# 1st Conv2D layer
model.add(Convolution2D(
    filters=32,
    kernel_size=(5, 5),
    padding="same",
    input_shape=(28, 28, 1),
))

model.add(Activation("relu"))

model.add(MaxPool2D(
    pool_size=(2, 2),
    strides=(2, 2),
    padding="same",
))

# 2nd Conv2D layer
model.add(Convolution2D(
    filters=64,
    kernel_size=(5, 5),
    padding="same"
))

model.add(Activation("relu"))

model.add(MaxPool2D(
    pool_size=(2, 2),
    strides=(2, 2),
    padding="same",
))
# 卷积完成后生成的形状为[[[1,2,3],[2,3,4][4,5,6]]],在此项目中,因为是灰度图,所以是3维,一般是4维
# 根据实际情况加卷积层数。

# 1st Fully connected layer
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
# model.add(Dropout(0.2))

# 2nd Fully connected layer
model.add(Dense(nb_class))
model.add(Activation("softmax"))


"""3编译"""
adam = Adam(lr=0.01)  # 实例化, 0.01根据实际情况来
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)


"""4启动网络-训练网络"""
model.fit(
    x=X_train,
    y=Y_train,
    batch_size=batch_size,
    epochs=nb_epoch,  # 输入一个字母,有等号的提示是需要参数(最前面有个小提示P),
    # 没等号的提示是已经定义好的变量(最前面有个小提示V)。
    verbose=1,
    validation_data=[X_test, Y_test],
)
model.save("./conv_model.h5")  # 模型的保存
# 旧版本中写下面实现上面validation_data=[X_test, Y_test]的功能
# evaluation = model.evaluate(X_test, Y_test)
# print(evaluation)

4. keras_mnist_cnn_predict.py

功能:加载h5文件,实现对test.jpg(灰度图像)的预测

# coding:utf-8

import numpy as np
from keras.models import load_model
import matplotlib.pyplot as plt
import matplotlib.image as mpimage
from PIL import Image

model = load_model("conv_model.h5")


class PredictImg(object):
    def __init__(self):
        pass

    # filename为图片名字
    def pred(self, filename):
        image = Image.open(filename)
        image_L = image.convert("L")
        image_L = image_L.resize((28, 28), Image.ANTIALIAS)
        image_L = np.array(image_L)
        image_L = image_L / 255.  # 如果少了这一步,那么返回的prediction的10个值,要么是0要么是1
        image_L = image_L.reshape(-1, 28, 28, 1)
        prediction = model.predict(image_L)
        print(prediction)
        print(prediction[0])
        Final_prediction = np.argmax(prediction)
        a = 0
        for i in prediction[0]:
            print(a)
            print("percent:%.4f" % i)  # 输出百分比
            a = a + 1
        return Final_prediction


def main():
    Predict = PredictImg()
    res = Predict.pred("test.jpg")
    print("预测结果为:", res)


if __name__ == "__main__":
    main()

5.识别结果

(1)test.jpg 如下:
keras入门系列(二)——实现卷积神经网络手写数字的分类识别_第2张图片
(2)测试结果为:keras入门系列(二)——实现卷积神经网络手写数字的分类识别_第3张图片

你可能感兴趣的:(keras系列)