使用AlexNet8网络实现10分类

参考资料: 北京大学, 软微学院, 曹健老师, 《人工智能实践:TensorFlow2.0笔记》
运行环境:
python3.7
tensorflow 2.1.0
numpy 1.17.4
matplotlib 3.2.1

AlexNet8的训练参考文章:AlexNet8网络在python下的实现
训练数据集:cifar10

值得注意的是:只要替换掉 Sequential 中的网络结构,即可更换为 VGG16、ResNet 等其他模型。

构造网络模型:

# Convolutional feature extractor: two Conv -> BatchNorm -> ReLU -> MaxPool
# stages, then three 'same'-padded ReLU convolutions and a final max-pool.
feature_layers = [
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
]

# Classifier head: two 2048-unit ReLU layers, each followed by 50% dropout,
# ending in a 10-way softmax.
classifier_layers = [
    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax'),
]

# Layers are stacked in exactly the order listed above.
model = tf.keras.models.Sequential(feature_layers + classifier_layers)

加载w和b:
这里的路径(model_save_path)使用前面训练时保存权重的路径

# Load the trained weights and biases from the path stored in model_save_path.
model.load_weights(model_save_path)

选择要识别的图片:

# The network was trained on CIFAR-10, whose images are 32x32 RGB. Test
# pictures must be brought to that same shape, otherwise the flattened
# feature size no longer matches the saved Dense weights.
INPUT_SIZE = (32, 32)

# First prompt: how many pictures to classify.
preNum = int(input('input the number of test pictures:'))

for _ in range(preNum):
    # Following prompts: file name of each picture, looked up in the
    # alexnet8_pictures folder under the current directory.
    image_path = input('the name of the test picture:')
    image_path = './alexnet8_pictures/' + image_path

    # Force 3-channel RGB so grayscale / RGBA / palette files still produce
    # an (H, W, 3) array; the context manager closes the file handle.
    with Image.open(image_path) as picture:
        img = picture.convert('RGB')

    # Show the picture at its original resolution.
    plt.imshow(img)

    # Downsample to the network's input resolution. LANCZOS is the filter
    # that ANTIALIAS used to alias (ANTIALIAS is gone in recent Pillow).
    img = img.resize(INPUT_SIZE, Image.LANCZOS)

    # Convert to a numpy array and scale pixel values into [0, 1].
    img_arr = np.array(img) / 255

    # Add the batch dimension: 32x32x3 -> 1x32x32x3.
    x_predict = img_arr[tf.newaxis, ...]

    # Run the prediction and print the raw class probabilities.
    result = model.predict(x_predict)
    print('result: ' + str(result))

    # Pick the most probable class (axis=1 runs across the class columns).
    pred = tf.argmax(result, axis=1)
    category = str(pred.numpy())

    # Print the class name that corresponds to the predicted index.
    pc.print_category(category)
    print('\n')
    tf.print(pred)
    plt.pause(1)
    plt.close()

这里我拆分出了另一个.py用来打印结果
print_category.py

下面给出完整代码
完成预测:
alexnet8_app.py:

# alexnet8应用
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
import matplotlib.pyplot as plt
import print_category as pc
# Checkpoint location of the trained AlexNet8 weights.
model_save_path = './checkpoint/AlexNet8.ckpt'

# Convolutional feature extractor: two Conv -> BatchNorm -> ReLU -> MaxPool
# stages, then three 'same'-padded ReLU convolutions and a final max-pool.
feature_layers = [
    Conv2D(filters=96, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=256, kernel_size=(3, 3)),
    BatchNormalization(),
    Activation('relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),

    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu'),
    Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu'),
    MaxPool2D(pool_size=(3, 3), strides=2),
]

# Classifier head: two 2048-unit ReLU layers, each followed by 50% dropout,
# ending in a 10-way softmax.
classifier_layers = [
    Flatten(),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(2048, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax'),
]

# Layers are stacked in exactly the order listed above.
model = tf.keras.models.Sequential(feature_layers + classifier_layers)

# The final layer already applies softmax, so the loss is told that it
# receives probabilities rather than raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['sparse_categorical_accuracy'],
)

# Load the trained weights and biases.
model.load_weights(model_save_path)

# The network was trained on CIFAR-10, whose images are 32x32 RGB. Test
# pictures must be brought to that same shape, otherwise the flattened
# feature size no longer matches the saved Dense weights.
INPUT_SIZE = (32, 32)

# Ask how many pictures to classify.
preNum = int(input('input the number of test pictures:'))

for _ in range(preNum):
    # Ask for the file name; pictures live in ./alexnet8_pictures/.
    image_path = input('the name of the test picture:')
    image_path = './alexnet8_pictures/' + image_path

    # Force 3-channel RGB so grayscale / RGBA / palette files still produce
    # an (H, W, 3) array; the context manager closes the file handle.
    with Image.open(image_path) as picture:
        img = picture.convert('RGB')

    # Show the picture at its original resolution.
    plt.imshow(img)

    # Downsample to the network's input resolution. LANCZOS is the filter
    # that ANTIALIAS used to alias (ANTIALIAS is gone in recent Pillow).
    img = img.resize(INPUT_SIZE, Image.LANCZOS)

    # Convert to a numpy array and scale pixel values into [0, 1].
    img_arr = np.array(img) / 255

    # Add the batch dimension: 32x32x3 -> 1x32x32x3.
    x_predict = img_arr[tf.newaxis, ...]

    # Run the prediction and print the raw class probabilities.
    result = model.predict(x_predict)
    print('result: ' + str(result))

    # Pick the most probable class and print its name.
    pred = tf.argmax(result, axis=1)
    category = str(pred.numpy())
    pc.print_category(category)
    print('\n')
    tf.print(pred)
    plt.pause(1)
    plt.close()

打印结果:
print_category.py:

def print_category(category):
    """Print the Chinese CIFAR-10 class name for a predicted label.

    Args:
        category: Predicted class index rendered as a string such as
            ``'[3]'`` (the string form of a one-element array).

    Any value that does not equal ``'[0]'`` .. ``'[8]'`` is reported as
    the last class (truck).
    """
    # Names for class indices 0-8, in index order; index 9 is the fallback.
    names = ('飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船')
    for index, name in enumerate(names):
        if category == f'[{index}]':
            print(name)
            return
    print('卡车')

你可能感兴趣的:(使用AlexNet8网络实现10分类)