TensorFlow2 实现动物识别(90类)MobileNetV2算法(内附源码与数据)

本文已加入 Python AI 计划,从一个Python小白到一个AI大神,你所需要的所有知识都在 这里 了。


在之前的文章中我们通过Xception算法模型实现了狗、猫、鸡、马四种的动物的识别(新模型!实现动物识别)。今天我们接着介绍MobileNetV2算法,将数据集扩充到90个类别,即使用 90 个不同类别的动物图片,每个类别分别含有60张图片,一共 5400 张图片进行识别。最后达到的准确率是86.2% 。代码与数据我放在文末了,需要的自取

我的环境:

  • 语言环境:Python3.8
  • 编译器:Jupyter lab
  • 深度学习环境:TensorFlow2.4.1
  • 选自专栏:《深度学习100例》

我们的代码流程图如下所示:

TensorFlow2 实现动物识别(90类)MobileNetV2算法(内附源码与数据)_第1张图片

文章目录

  • 一、设置GPU
  • 二、导入数据
    • 1. 查看数据
    • 2. 加载数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建MobileNetV2迁移模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keras

warnings.filterwarnings("ignore")#忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

二、导入数据

1. 查看数据

import pathlib

data_dir = "./30-data/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 5400
# 统计每一个类别的数目
class_name = []
class_sum  = []

for i in data_dir.glob('*'):

    class_name.append(str(i).split("\\")[1])
    class_sum.append(len(list(i.glob('*'))))

class_dict = {'class_name':class_name,'class_sum':class_sum}
class_df   = pd.DataFrame(class_dict,columns=['class_name', 'class_sum'])
# 按照图片数量进行降序排序
class_df = class_df.sort_values(by="class_sum" , ascending=False)
class_df.head()
class_name class_sum
0 antelope 60
67 raccoon 60
65 porcupine 60
64 pigeon 60
63 pig 60

在实验开始时查看数据集分布情况,部分类别图片数量过少时,需要及时补充数据。

2. 加载数据

batch_size = 16
img_height = 224
img_width  = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789

通过该方法导入数据时,会同时打乱数据
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 5400 files belonging to 90 classes.
Using 4320 files for training.
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789

通过该方法导入数据时,会同时打乱数据
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 5400 files belonging to 90 classes.
Using 1080 files for validation.
class_names = train_ds.class_names
print("数据类别有:",class_names)
print("需要识别的动物一共有%d类"%len(class_names))
数据类别有: ['antelope', 'badger', 'bat', 'bear', 'bee', 'beetle', 'bison', 'boar', 'butterfly', 'cat', 'caterpillar', 'chimpanzee', 'cockroach', 'cow', 'coyote', 'crab', 'crow', 'deer', 'dog', 'dolphin', 'donkey', 'dragonfly', 'duck', 'eagle', 'elephant', 'flamingo', 'fly', 'fox', 'goat', 'goldfish', 'goose', 'gorilla', 'grasshopper', 'hamster', 'hare', 'hedgehog', 'hippopotamus', 'hornbill', 'horse', 'hummingbird', 'hyena', 'jellyfish', 'kangaroo', 'koala', 'ladybugs', 'leopard', 'lion', 'lizard', 'lobster', 'mosquito', 'moth', 'mouse', 'octopus', 'okapi', 'orangutan', 'otter', 'owl', 'ox', 'oyster', 'panda', 'parrot', 'pelecaniformes', 'penguin', 'pig', 'pigeon', 'porcupine', 'possum', 'raccoon', 'rat', 'reindeer', 'rhinoceros', 'sandpiper', 'seahorse', 'seal', 'shark', 'sheep', 'snake', 'sparrow', 'squid', 'squirrel', 'starfish', 'swan', 'tiger', 'turkey', 'turtle', 'whale', 'wolf', 'wombat', 'woodpecker', 'zebra']
需要识别的动物一共有90类
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(16, 224, 224, 3)
(16,)

3. 配置数据集

  • shuffle() : 打乱数据。
  • prefetch() : 预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() : 将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
#     .shuffle(2000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
#     .shuffle(2000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

TensorFlow2 实现动物识别(90类)MobileNetV2算法(内附源码与数据)_第2张图片

三、构建MobileNetV2迁移模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout,BatchNormalization,Activation

# 加载预训练模型
base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(weights='imagenet',
                                                            include_top=False,
                                                            input_shape=(img_width,img_height,3),
                                                            pooling='max')

for layer in base_model.layers:
    layer.trainable = True
    
X = base_model.output
"""
注意到原模型(MobileNetV2)会发生过拟合现象,这里加上一个Dropout层
加上后,过拟合现象得到了明显的改善。
大家可以试着通过调整代码,观察一下注释Dropout层与不注释之间的差别
"""
X = Dropout(0.6)(X)

output = Dense(len(class_names), activation='softmax')(X)
model = Model(inputs=base_model.input, outputs=output)

# model.summary()

四、编译

optimizer = tf.keras.optimizers.Adam(lr=1e-4)

model.compile(optimizer=optimizer,
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateScheduler

NO_EPOCHS = 20
PATIENCE  = 5
VERBOSE   = 1

# 设置动态学习率
annealer = LearningRateScheduler(lambda x: 1e-4 * 0.98 ** x)

# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)

# 
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=VERBOSE,
                                save_best_only=True,
                                save_weights_only=True)
train_model  = model.fit(train_ds,
                  epochs=NO_EPOCHS,
                  verbose=1,
                  validation_data=val_ds,
                  callbacks=[annealer, earlystopper, checkpointer])
Epoch 1/20
270/270 [==============================] - 24s 65ms/step - loss: 9.6472 - accuracy: 0.0667 - val_loss: 6.1488 - val_accuracy: 0.1407racy - ETA: 13s - loss: 16.8260 - ac - ETA: 5s - loss: 11.5347

Epoch 00001: val_accuracy improved from -inf to 0.14074, saving model to best_model.h5
Epoch 2/20
270/270 [==============================] - 16s 58ms/step - loss: 3.1554 - accuracy: 0.3285 - val_loss: 2.4029 - val_accuracy: 0.4852
..........
Epoch 00019: val_accuracy did not improve from 0.84167
Epoch 20/20
270/270 [==============================] - 16s 57ms/step - loss: 0.0514 - accuracy: 0.9843 - val_loss: 0.7325 - val_accuracy: 0.8250

Epoch 00020: val_accuracy did not improve from 0.84167

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']

loss = train_model.history['loss']
val_loss = train_model.history['val_loss']

epochs_range = range(len(acc))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

TensorFlow2 实现动物识别(90类)MobileNetV2算法(内附源码与数据)_第3张图片

加入Dropout层后过拟合现象得到了缓解,没有那么明显了。

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('混淆矩阵',fontsize=15)
    plt.ylabel('真实值',fontsize=14)
    plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []

for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
    for image, label in zip(images, labels):
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(image, 0) 
        # 使用模型预测图片中的人物
        prediction = model.predict(img_array)

        val_pre.append(class_names[np.argmax(prediction)])
        val_label.append(class_names[label])
plot_cm(val_label, val_pre)

TensorFlow2 实现动物识别(90类)MobileNetV2算法(内附源码与数据)_第4张图片
90个类别做成混淆矩阵,基本就看不出东西了,这里放上混淆矩阵的代码主要是方便大家切换成其他数据集时使用。

3. 各项指标评估

from sklearn import metrics

def test_accuracy_report(model):
    print(metrics.classification_report(val_label, val_pre, target_names=class_names)) 
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model)
                precision    recall  f1-score   support

      antelope       0.52      1.00      0.68        15
        badger       1.00      0.83      0.91        12
           bat       0.55      0.50      0.52        12
           
          ...此处省略若干...
          
         zebra       0.80      0.92      0.86        13

      accuracy                           0.82      1080
     macro avg       0.85      0.83      0.83      1080
  weighted avg       0.85      0.82      0.82      1080

Loss function: 0.7324997782707214, accuracy: 0.824999988079071

本文的数据与代码传送门


你可能感兴趣的:(深度学习100例,算法,python,深度学习)