import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
2、获取数据集并作预处理
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()
#图像被编码为Numpy数组,而标签只是一组数字,从0到9,图像和标签之间存在一一对应关系
#输出训练集的特征的大小
#reshape:给数据增加一个维度,使数据和网络结构匹配
Train_images = tf.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
print(Train_images.shape)
Test_images = tf.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))
print(Test_images.shape)
3、搭建LeNet网络
net = tf.keras.models.Sequential([
#第一层:6个5*5的卷积核,全0填充;最大池化,2*2的池化核,步长为2,padding='VALID'
tf.keras.layers.Conv2D(filters=6,kernel_size=5,activation='sigmoid',input_shape=(28,28,1),padding='same'),
tf.keras.layers.MaxPool2D(pool_size=2,strides=2),
#第二层
tf.keras.layers.Conv2D(filters=16,kernel_size=5,activation='sigmoid',padding='same'),
tf.keras.layers.MaxPool2D(pool_size=2,strides=2),
#拉直,将(28,28)像素的图像即对应的2维的数组转成一维的数组
tf.keras.layers.Flatten(),
#三层全连接网络
#120个神经元
tf.keras.layers.Dense(120,activation='sigmoid'),
tf.keras.layers.Dense(84,activation='sigmoid'),
tf.keras.layers.Dense(10,activation='softmax')
])
4、训练
#损失函数和训练算法采用交叉熵损失函数(cross entropy)和小批量随机梯度下降(SGD)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.9,momentum=0.0,nesterov=False)
#编译模型
net.compile(optimizer=optimizer,
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#训练模型,训练次数为5次
#validation_split用来指定训练集的一定比例数据为验证集
net.fit(Train_images,train_labels,epochs=5,validation_split=0.1)
5、评估准确率
test_loss, test_acc = net.evaluate(Test_images, test_labels)
print('\nTest accuracy:', test_acc)
6、对测试集图片进行预测
#对测试集图片进行预测
Predictions=net.predict(Test_images)
#输出第一张图片的预测结果
print(Predictions[0])
print("The first picture's prediction is:",np.argmax(Predictions[0]))
print("the first picture is:",test_labels[0])
7、绘制预测结果
#类别名称
class_names=['T-shirt/top','Trouser','Pullover','Dress','Coat',
'Sandal','Shirt','Sneaker','Bag','Ankle boot']
#绘制前25个样本的预测结果,正确为绿色,不正确为红色
plt.figure(figsize=(15, 15))
for i in range(25):
plt.subplot(5, 5, i + 1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(test_images[i], cmap=plt.cm.binary)
predicted_label = np.argmax(Predictions[i])
true_label = test_labels[i]
if predicted_label == true_label:
color = 'green'
else:
color = 'red'
#前面是预测分类,括号内是实际分类
plt.xlabel("{} ({})".format(class_names[predicted_label],
class_names[true_label]),
color=color)
plt.show()
参考:https://www.cnblogs.com/CuteyThyme/p/12741241.html