tensorflow2实现LeNet-5网络

利用lenet网络实现图片分类

网路分类
tensorflow2实现LeNet-5网络_第1张图片
1、导入相关包

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

2、获取数据集并作预处理

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()
#图像被编码为Numpy数组,而标签只是一组数字,从0到9,图像和标签之间存在一一对应关系

#输出训练集的特征的大小
#reshape:给数据增加一个维度,使数据和网络结构匹配
Train_images = tf.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
print(Train_images.shape)
Test_images = tf.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))
print(Test_images.shape)

3、搭建LeNet网络

net = tf.keras.models.Sequential([
    #第一层:6个5*5的卷积核,全0填充;最大池化,2*2的池化核,步长为2,padding='VALID'
    tf.keras.layers.Conv2D(filters=6,kernel_size=5,activation='sigmoid',input_shape=(28,28,1),padding='same'),
    tf.keras.layers.MaxPool2D(pool_size=2,strides=2),
    #第二层
    tf.keras.layers.Conv2D(filters=16,kernel_size=5,activation='sigmoid',padding='same'),
    tf.keras.layers.MaxPool2D(pool_size=2,strides=2),
    #拉直,将(28,28)像素的图像即对应的2维的数组转成一维的数组
    tf.keras.layers.Flatten(),
    #三层全连接网络
    #120个神经元
    tf.keras.layers.Dense(120,activation='sigmoid'),
    tf.keras.layers.Dense(84,activation='sigmoid'),
    tf.keras.layers.Dense(10,activation='softmax')
])

4、训练

#损失函数和训练算法采用交叉熵损失函数(cross entropy)和小批量随机梯度下降(SGD)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.9,momentum=0.0,nesterov=False)

#编译模型
net.compile(optimizer=optimizer,
           loss='sparse_categorical_crossentropy',
           metrics=['accuracy'])

#训练模型,训练次数为5次
#validation_split用来指定训练集的一定比例数据为验证集
net.fit(Train_images,train_labels,epochs=5,validation_split=0.1)

5、评估准确率

test_loss, test_acc = net.evaluate(Test_images, test_labels)
print('\nTest accuracy:', test_acc)

6、对测试集图片进行预测

#对测试集图片进行预测
Predictions=net.predict(Test_images)
#输出第一张图片的预测结果
print(Predictions[0])
print("The first picture's prediction is:",np.argmax(Predictions[0]))
print("the first picture is:",test_labels[0])

7、绘制预测结果

#类别名称
class_names=['T-shirt/top','Trouser','Pullover','Dress','Coat',
             'Sandal','Shirt','Sneaker','Bag','Ankle boot']
#绘制前25个样本的预测结果,正确为绿色,不正确为红色
plt.figure(figsize=(15, 15))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid('off')
    plt.imshow(test_images[i], cmap=plt.cm.binary)
    predicted_label = np.argmax(Predictions[i])
    true_label = test_labels[i]
    if predicted_label == true_label:
        color = 'green'
    else:
        color = 'red'
    #前面是预测分类,括号内是实际分类
    plt.xlabel("{} ({})".format(class_names[predicted_label],
                                class_names[true_label]),
               color=color)
plt.show()

参考:https://www.cnblogs.com/CuteyThyme/p/12741241.html

你可能感兴趣的:(tensorflow2实现LeNet-5网络)