1、下载数据集
CIFAR-10是一个用于识别普适物 体的小型数据集,它包含了10个类 别的RGB彩色图片
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
cifar10=tf.keras.datasets.cifar10
(Xtrain, Ytrain), (Xtest, Ytest) = cifar10.load_data()
2、数据预处理
Xtrain_normalize=Xtrain.astype("float32")/255.0
Xtest_normalize=Xtrain.astype("float32")/255.0
Ytrain_ohe=keras.utils.to_categorical(Ytrain)
Ytest_ohe=keras.utils.to_categorical(Ytest)
3、建立卷神经网络CNN模型
图像的特征提取:通过卷积层1,降采样层1,卷积层2以及降采样层2的处理,提取图像的特征
全连接神经网络:全连接层、输出层所组成的网络结构
model = tf.keras.models.Sequential()
model.add(layers.Conv2D(filters=32,
kernel_size=(3, 3),
input_shape=(32, 32, 3),
activation='relu',
padding='same'))
model.add(tf.keras.layers.Dropout(rate=0.3))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Conv2D(filters=64,
kernel_size=(3, 3),
activation= 'relu',
padding='same'))
model.add(tf.keras.layers.Dropout(rate=0.3))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(10,activation='softmax'))
4、模型摘要
model.summary()
5、设置模型训练超参数
train_epochs=5
batch_size=100
6、设置模型训练模式
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
7、模型训练
model.fit(Xtrain_normalize,Ytrain_ohe,
validation_split = 0.2,
epochs = train_epochs,
batch_size=batch_size,
verbose = 2)
8、评估模型及预测
# 评估模型
test_loss, test_acc = model.evaluate(Xtest_normalize, Ytest_ohe)
print('Test accuracy:', test_acc)
# 进行预测
predictions = model.predict(Xtest_normalize)
# 定义标签字典 每一个数字所代表的图像类别的名称
label_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer",
5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
9、可视化预测结果
# 定义显示图像数据及其对应标签的函数
def plot_images_labels_prediction(images, # 图像列表
labels, # 标签列表
prediction, # 预测值列表
index, # 从第index个开始显示
num = 5 ): # 缺省一次显示5幅
fig = plt.gcf() # 获取当前图表,Get Current Figure
fig.set_size_inches(12, 6) # 1英寸等于 2.54 cm
if num > 10:
num = 10 # 最多显示10个子图
for i in range(0, num):
ax = plt.subplot(2, 5, i + 1) # 获取当前要处理的子图
ax.imshow(images[index], # 显示第index个图像
cmap = 'binary')
title = str(i) + ',' + label_dict[np.argmax(labels[index])] # 构建该图上要显示的title信息
if len(prediction) > 0:
title += ' => ' + label_dict[np.argmax(predictions[index])]
ax.set_title(title,fontsize = 10) # 显示图上的title信息
index += 1
plt.show()
plot_images_labels_prediction(Xtest_normalize,
Ytest_ohe,
predictions,0,10)