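Reference: Prof. Cao Jian, School of Software and Microelectronics, Peking University, course notes "Artificial Intelligence Practice: TensorFlow 2.0" (《人工智能实践:TensorFlow2.0笔记》)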
Environment:
python3.7
tensorflow 2.1.0
numpy 1.17.4
matplotlib 3.2.1
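For how AlexNet8 was trained, see the article: AlexNet8网络在python下的实现 (Implementing the AlexNet8 network in Python)
Training dataset: CIFAR-10
Note that you can switch to another model such as VGG16 or ResNet by replacing the network definition inside Sequential (ResNet's shortcut connections need a tf.keras.Model subclass rather than a plain Sequential stack) and pointing model_save_path at that model's checkpoint. The definition must be identical to the one used in training, otherwise the saved weights will not load.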
Build the network model:
```python
model = tf.keras.models.Sequential([
    # Network structure; the input shape (32*32*3 for CIFAR-10) is inferred on the first call
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])
```
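Why the input resolution matters: Conv2D and MaxPool2D default to padding='valid', so every layer above that does not set padding='same' shrinks the feature map, and the length of the Flatten() output (which fixes the weight shape of the first Dense layer) depends on the input size. The small pure-Python trace below is a sketch written for this note, not part of Keras; valid_out and trace_alexnet8 are hypothetical helper names. It applies the 'valid' output-size formula floor((n - k) / s) + 1 layer by layer. For a 32×32 input (the CIFAR-10 size) the last feature map is 2×2×256, so Flatten yields 1024 values, while a 28×28 input gives 1×1×256, i.e. 256 values, which does not fit weights trained on 32×32 images.

```python
def valid_out(size, kernel, stride=1):
    """Output length along one axis of a 'valid' (unpadded) conv/pool window."""
    return (size - kernel) // stride + 1


def trace_alexnet8(input_size):
    """Return (final spatial size, Flatten length) for a square input."""
    s = valid_out(input_size, 3)   # Conv2D(96, 3x3), default padding='valid'
    s = valid_out(s, 3, 2)         # MaxPool2D(3x3, strides=2)
    s = valid_out(s, 3)            # Conv2D(256, 3x3), default padding='valid'
    s = valid_out(s, 3, 2)         # MaxPool2D(3x3, strides=2)
    # The three Conv2D(..., padding='same') layers keep the spatial size unchanged
    s = valid_out(s, 3, 2)         # MaxPool2D(3x3, strides=2)
    return s, s * s * 256          # last conv block has 256 filters


for n in (32, 28):
    side, flat = trace_alexnet8(n)
    print(f"input {n}x{n}: feature map {side}x{side}x256, Flatten length {flat}")
```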
Load the weights (w and b):
model_save_path is the checkpoint path that was saved during training.
```python
model.load_weights(model_save_path)
```
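Choose the pictures to recognize:
```python
# First prompt: the number of pictures to recognize
preNum = int(input('input the number of test pictures:'))
for i in range(preNum):
    # Second prompt: the file name of the picture to recognize
    image_path = input('the name of the test picture:')
    # The pictures are stored in the alexnet8_pictures folder under the current directory
    image_path = './alexnet8_pictures/' + image_path
    # convert('RGB') drops an alpha channel (e.g. in PNG files) so the array has 3 channels
    img = Image.open(image_path).convert('RGB')
    # Show the picture
    image = plt.imread(image_path)
    plt.imshow(image)
    # Downscale to 32*32, the resolution of the CIFAR-10 images the model was trained on
    # (Image.LANCZOS is the filter that the older alias Image.ANTIALIAS referred to)
    img = img.resize((32, 32), Image.LANCZOS)
    # Convert to a numpy array and scale pixel values to [0, 1]
    img_arr = np.array(img)
    # print(img_arr.shape)
    img_arr = img_arr / 255
    # Add a batch axis: 32*32*3 -> 1*32*32*3
    x_predict = img_arr[tf.newaxis, ...]
    # Run the prediction; result holds the probabilities of the 10 classes
    result = model.predict(x_predict)
    print('result: ' + str(result))
    # Take the class with the largest probability as the prediction;
    # axis=1 takes the argmax across the columns (the 10 classes) of each row
    pred = tf.argmax(result, axis=1)
    category = str(pred.numpy())
    # Print the class name according to the value of pred
    pc.print_category(category)
    print('\n')
    tf.print(pred)
    plt.pause(1)
    plt.close()
```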
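The preprocessing step in the loop can be checked without a model or a picture file. The numpy-only sketch below assumes the picture has already been resized to 32×32 and uses a random array as a stand-in for np.array(img); to_model_input is a hypothetical helper name, not something the article defines. np.newaxis, like tf.newaxis, is simply None, so indexing with it inserts the batch axis that model.predict expects.

```python
import numpy as np


def to_model_input(img_arr):
    """Turn an HxW or HxWxC uint8 image array into a 1xHxWx3 float batch in [0, 1]."""
    img_arr = np.asarray(img_arr)
    if img_arr.ndim == 2:                        # grayscale: replicate to 3 channels
        img_arr = np.stack([img_arr] * 3, axis=-1)
    img_arr = img_arr[..., :3]                   # drop an alpha channel if present
    img_arr = img_arr.astype(np.float32) / 255.0
    return img_arr[np.newaxis, ...]              # add the batch axis (same as tf.newaxis)


# Hypothetical stand-in for an already-resized 32x32 RGBA picture
fake_rgba = np.random.randint(0, 256, size=(32, 32, 4), dtype=np.uint8)
x_predict = to_model_input(fake_rgba)
print(x_predict.shape, x_predict.dtype)
```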
I moved the code that prints the result into a separate .py file:
print_category.py
The complete code is given below.
Run the prediction:
alexnet8_app.py:
```python
# AlexNet8 application: recognize your own pictures with the trained model
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
import matplotlib.pyplot as plt
import print_category as pc

# Checkpoint path saved during training
model_save_path = './checkpoint/AlexNet8.ckpt'

model = tf.keras.models.Sequential([
    # Network structure (must be identical to the trained model)
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Load the weights (w and b)
model.load_weights(model_save_path)

# Choose the pictures
preNum = int(input('input the number of test pictures:'))
for i in range(preNum):
    image_path = input('the name of the test picture:')
    image_path = './alexnet8_pictures/' + image_path
    # Force 3 channels (PNG files may carry an alpha channel)
    img = Image.open(image_path).convert('RGB')
    # Show the picture
    image = plt.imread(image_path)
    plt.imshow(image)
    # Downscale to 32*32, the CIFAR-10 resolution used in training
    # (Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS)
    img = img.resize((32, 32), Image.LANCZOS)
    # Convert to numpy and scale to [0, 1]
    img_arr = np.array(img)
    # print(img_arr.shape)
    img_arr = img_arr / 255
    # Reshape to 1*32*32*3 (add the batch axis)
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)
    print('result: ' + str(result))
    # Take the class with the largest probability as the prediction
    pred = tf.argmax(result, axis=1)
    category = str(pred.numpy())
    pc.print_category(category)
    print('\n')
    tf.print(pred)
    plt.pause(1)
    plt.close()
```
Print the result:
print_category.py:
```python
def print_category(category):
    # category is the string form of pred.numpy(), e.g. '[3]';
    # the labels follow the CIFAR-10 class order 0-9
    if category == '[0]':
        print('airplane')
    elif category == '[1]':
        print('automobile')
    elif category == '[2]':
        print('bird')
    elif category == '[3]':
        print('cat')
    elif category == '[4]':
        print('deer')
    elif category == '[5]':
        print('dog')
    elif category == '[6]':
        print('frog')
    elif category == '[7]':
        print('horse')
    elif category == '[8]':
        print('ship')
    else:
        print('truck')
```
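print_category compares the string form of pred.numpy(), such as '[3]', against ten literals, and any unexpected value silently falls through to the final else branch (truck). A sketch of an alternative, assuming the standard CIFAR-10 label order used by print_category: index a list with the integer argmax of the softmax output. CIFAR10_CLASSES, decode_prediction and the sample probabilities are hypothetical and serve only as an illustration.

```python
import numpy as np

# CIFAR-10 label order (index 0..9), same order as the branches of print_category
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


def decode_prediction(result):
    """Map a (1, 10) softmax output to (class index, class name, probability)."""
    probs = np.asarray(result)[0]
    idx = int(np.argmax(probs))          # equivalent to tf.argmax(result, axis=1)[0]
    return idx, CIFAR10_CLASSES[idx], float(probs[idx])


# Hypothetical model output for a single picture
result = np.array([[0.01, 0.02, 0.05, 0.70, 0.08, 0.06, 0.03, 0.02, 0.02, 0.01]])
print(decode_prediction(result))
```

In alexnet8_app.py, this would replace the str(pred.numpy()) and pc.print_category(category) pair with something like print(decode_prediction(result)[1]); an out-of-range index would then raise an IndexError instead of being reported as a truck.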