使用Lenet5对mnist数据集进行训练和测试

1.先训练数据

import tensorflow as tf
from tensorflow.keras import Sequential, layers, optimizers

# 加载数据集
mnist = tf.keras.datasets.mnist
(trainImage, trainLabel),(testImage, testLabel) = mnist.load_data()
 
for i in [trainImage,trainLabel,testImage,testLabel]:
    print(i.shape)

trainImage = tf.reshape(trainImage,(60000,28,28,1))
testImage = tf.reshape(testImage,(10000,28,28,1))
 
for i in [trainImage,trainLabel,testImage,testLabel]:
    print(i.shape)

# 网络定义
network = Sequential([
    # 卷积层1
    layers.Conv2D(filters=6,kernel_size=(5,5),activation="relu",input_shape=(28,28,1),padding="same"),
    layers.MaxPool2D(pool_size=(2,2),strides=2),
    
    # 卷积层2
    layers.Conv2D(filters=16,kernel_size=(5,5),activation="relu",padding="same"),
    layers.MaxPool2D(pool_size=2,strides=2),
    
    # 卷积层3
    layers.Conv2D(filters=32,kernel_size=(5,5),activation="relu",padding="same"),
    
    layers.Flatten(),
    
    # 全连接层1
    layers.Dense(200,activation="relu"),
    
    # 全连接层2
    layers.Dense(10,activation="softmax")    
])
network.summary()

# 模型训练 训练30个epoch
network.compile(optimizer='adam',loss="sparse_categorical_crossentropy",metrics=["accuracy"])
network.fit(trainImage,trainLabel,epochs=30,validation_split=0.1)

# 模型保存
network.save('./lenet_mnist.h5')
print('lenet_mnist model saved')
del network

2.生成数据集

使用Lenet5对mnist数据集进行训练和测试_第1张图片
#3. 测试数据集代码

import tensorflow as tf 
from tensorflow import keras 
import matplotlib.pyplot as plt

# 网络加载
network = keras.models.load_model('lenet_mnist.h5')
network.summary()

# 读取数据集
mnist = tf.keras.datasets.mnist
(trainImage, trainLabel),(testImage, testLabel) = mnist.load_data()
 
for i in [trainImage,trainLabel,testImage,testLabel]:
    print(i.shape)

# 显示前25张图片
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.imshow(testImage[i], cmap='gray')
plt.show()
# 改变维度 
testImage = tf.reshape(testImage,(10000,28,28,1))
# 结果预测
result = network.predict(testImage)[0:25]
pred = tf.argmax(result, axis=1)
pred_list=[]
for item in pred:
    pred_list.append(item.numpy())
print(pred_list)

4. 结果展示

使用Lenet5对mnist数据集进行训练和测试_第2张图片

5. 结语

希望对大家有所帮助!!!

你可能感兴趣的:(python,计算机视觉)