基于Fruits-360数据集构建CNN进行水果识别实验

基于Fruits-360数据集的水果识别项目

前段时间导师要求做一个神经网络可视化的项目,要将水果数据集进行训练得到模型,用于TensorSpace可视化。前前后后捣鼓了很久,现在回过头总结一下整个项目过程,写下这篇博客记录遇到的问题,有任何问题欢迎在评论区留言。

  • 1.实验数据集

  • (1)实验用的数据集是最常见的Fruits-360水果数据集,截至写博客为止这个数据集最新版本是2020.05.18.0
  • (2)该数据集有131种水果分类(分的特别细致,比方说苹果(不同品种:深红色雪,金黄色,金红色,史密斯奶奶,粉红色淑女,红色,红色美味)):Apples (different varieties: Crimson Snow, Golden, Golden-Red, Granny Smith, Pink Lady, Red, Red Delicious), Apricot, Avocado, Avocado ripe, Banana (Yellow, Red, Lady Finger), Beetroot Red, Blueberry, Cactus fruit, Cantaloupe (2 varieties), Carambula, Cauliflower, Cherry (different varieties, Rainier), Cherry Wax (Yellow, Red, Black), Chestnut, Clementine, Cocos, Corn (with husk), Cucumber (ripened), Dates, Eggplant, Fig, Ginger Root, Granadilla, Grape (Blue, Pink, White (different varieties)), Grapefruit (Pink, White), Guava, Hazelnut, Huckleberry, Kiwi, Kaki, Kohlrabi, Kumsquats, Lemon (normal, Meyer), Lime, Lychee, Mandarine, Mango (Green, Red), Mangostan, Maracuja, Melon Piel de Sapo, Mulberry, Nectarine (Regular, Flat), Nut (Forest, Pecan), Onion (Red, White), Orange, Papaya, Passion fruit, Peach (different varieties), Pepino, Pear (different varieties, Abate, Forelle, Kaiser, Monster, Red, Stone, Williams), Pepper (Red, Green, Orange, Yellow), Physalis (normal, with Husk), Pineapple (normal, Mini), Pitahaya Red, Plum (different varieties), Pomegranate, Pomelo Sweetie, Potato (Red, Sweet, White), Quince, Rambutan, Raspberry, Redcurrant, Salak, Strawberry (normal, Wedge), Tamarillo, Tangelo, Tomato (different varieties, Maroon, Cherry Red, Yellow, not ripened, Heart), Walnut, Watermelon.
  • (3)数据集属性
    • 图片总数:90483。
    • 训练集大小:67692张图像(每张图像一个水果)
    • 测试集大小:22688张图像(每张图像一个水果)
    • 种类数:131(水果)
    • 图片尺寸:100 x 100像素。
    • 文件名格式:图像索引_100.jpg(例如32_100.jpg)或r_图像索引_100.jpg(例如r_32_100.jpg)或r2_图像索引_100.jpg或r3_图像索引_100.jpg。“ r”代表旋转的水果。“ r2”表示水果绕第三轴旋转。“100”来自图像尺寸(100x100像素)。
    • 同一水果(例如苹果)的不同品种被存储为属于不同类别。
  • (4)数据集下载链接:Fruits-360.
  • 2.导包

实验用到了以下包:

// 导包
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import glob 
import cv2 
import tensorflow as tf
import keras
import os
print(os.listdir("../dataset/Training"))
['Apple Braeburn', 'Apple Crimson Snow', 'Apple Golden 1', 'Apple Golden 2', 'Apple Golden 3', 'Apple Granny Smith', 'Apple Pink Lady', 'Apple Red 1', 'Apple Red 2', 'Apple Red 3', 'Apple Red Delicious', 'Apple Red Yellow 1', 'Apple Red Yellow 2', 'Apricot', 'Avocado', 'Avocado ripe', 'Banana', 'Banana Lady Finger', 'Banana Red', 'Beetroot', 'Blueberry', 'Cactus fruit', 'Cantaloupe 1', 'Cantaloupe 2', 'Carambula', 'Cauliflower', 'Cherry 1', 'Cherry 2', 'Cherry Rainier', 'Cherry Wax Black', 'Cherry Wax Red', 'Cherry Wax Yellow', 'Chestnut', 'Clementine', 'Cocos', 'Corn', 'Corn Husk', 'Cucumber Ripe', 'Cucumber Ripe 2', 'Dates', 'Eggplant', 'Fig', 'Ginger Root', 'Granadilla', 'Grape Blue', 'Grape Pink', 'Grape White', 'Grape White 2', 'Grape White 3', 'Grape White 4', 'Grapefruit Pink', 'Grapefruit White', 'Guava', 'Hazelnut', 'Huckleberry', 'Kaki', 'Kiwi', 'Kohlrabi', 'Kumquats', 'Lemon', 'Lemon Meyer', 'Limes', 'Lychee', 'Mandarine', 'Mango', 'Mango Red', 'Mangostan', 'Maracuja', 'Melon Piel de Sapo', 'Mulberry', 'Nectarine', 'Nectarine Flat', 'Nut Forest', 'Nut Pecan', 'Onion Red', 'Onion Red Peeled', 'Onion White', 'Orange', 'Papaya', 'Passion Fruit', 'Peach', 'Peach 2', 'Peach Flat', 'Pear', 'Pear 2', 'Pear Abate', 'Pear Forelle', 'Pear Kaiser', 'Pear Monster', 'Pear Red', 'Pear Stone', 'Pear Williams', 'Pepino', 'Pepper Green', 'Pepper Orange', 'Pepper Red', 'Pepper Yellow', 'Physalis', 'Physalis with Husk', 'Pineapple', 'Pineapple Mini', 'Pitahaya Red', 'Plum', 'Plum 2', 'Plum 3', 'Pomegranate', 'Pomelo Sweetie', 'Potato Red', 'Potato Red Washed', 'Potato Sweet', 'Potato White', 'Quince', 'Rambutan', 'Raspberry', 'Redcurrant', 'Salak', 'Strawberry', 'Strawberry Wedge', 'Tamarillo', 'Tangelo', 'Tomato 1', 'Tomato 2', 'Tomato 3', 'Tomato 4', 'Tomato Cherry Red', 'Tomato Heart', 'Tomato Maroon', 'Tomato not Ripened', 'Tomato Yellow', 'Walnut', 'Watermelon']
  • 3.数据预处理

  • (1)训练集处理

// 训练集处理
training_fruit_img = []
training_label = []
for dir_path in glob.glob("../dataset/Training/*"):
    img_label = dir_path.split("/")[-1]
    for img_path in glob.glob(os.path.join(dir_path, "*.jpg")):
        img = cv2.imread(img_path)
        img = cv2.resize(img, (64, 64))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        training_fruit_img.append(img)
        training_label.append(img_label)
training_fruit_img = np.array(training_fruit_img)
training_label = np.array(training_label)
print(len(np.unique(training_label)))
131
  • (2)测试集处理

// 测试集处理
test_fruit_img = []
test_label = []
for dir_path in glob.glob("../dataset/Test/*"):
    img_label = dir_path.split("/")[-1]
    for img_path in glob.glob(os.path.join(dir_path, "*.jpg")):
        img = cv2.imread(img_path)
        img = cv2.resize(img, (64, 64))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        test_fruit_img.append(img)
        test_label.append(img_label)
test_fruit_img = np.array(test_fruit_img)
test_label = np.array(test_label)
print(len(np.unique(test_label)))
131
  • (3)测试集(混合水果图像)处理

// 测试集(混合水果图像)处理
test_fruits_img = []
tests_label = []
for img_path in glob.glob(os.path.join("../dataset/test-multiple_fruits", "*.jpg")):
     img_label = img_path.split("/")[-1]
     img = cv2.imread(img_path)
     img = cv2.resize(img, (64, 64))
     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
     test_fruits_img.append(img)
     tests_label.append(img_label)
test_fruits_img = np.array(test_fruits_img)
tests_label = np.array(tests_label)
len(np.unique(tests_label))
103
  • (4)训练集标签和id互相转化

trainging_label_to_id = {v : k for k, v in enumerate(np.unique(training_label))}
training_id_to_label = {v : k for k, v in trainging_label_to_id.items()}
training_label_id = np.array([trainging_label_to_id[i] for i in training_label])
print(training_label_id)
array([  0,   0,   0, ..., 130, 130, 130])
  • (5)测试集标签和id互相转化

  • 这部分同训练集,也是使用enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,即将测试集标签去重后组合为一个索引序列,并进行键值对对换,再转化为NumPy数组。
  • (6)像素值缩放

  • 这里我将训练集图像和测试集图像乘以1/255进行缩放,将像素值缩放至[0, 1]区间。
  • 然后使用matplotlib.pyplot将图像显示出来(实验中我以训练集第10001个图像为例)
    基于Fruits-360数据集构建CNN进行水果识别实验_第1张图片
  • 4.Keras模型构建

  • 我实验中的搭建的模型结构如下:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 64, 64, 16)        448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 32)        4640      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 32)        9248      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 8, 8, 32)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 64)          18496     
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               262400    
_________________________________________________________________
dense_2 (Dense)              (None, 131)               33667     
=================================================================
Total params: 328,899
Trainable params: 328,899
Non-trainable params: 0
_________________________________________________________________
  • 5.模型编译和训练

  • 我实验中使用的是交叉熵损失函数sparse_categorical_crossentropy,优化器使用的是Adamaxbatch_size设置为128,epochs设置为10,并且使用TensorBoard将训练过程可视化。
Train on 67692 samples
Epoch 1/10
67692/67692 [==============================] - 145s 2ms/sample - loss: 1.6539 - accuracy: 0.5947
Epoch 2/10
67692/67692 [==============================] - 156s 2ms/sample - loss: 0.2184 - accuracy: 0.9370
Epoch 3/10
67692/67692 [==============================] - 179s 3ms/sample - loss: 0.0812 - accuracy: 0.9764
Epoch 4/10
67692/67692 [==============================] - 264s 4ms/sample - loss: 0.0466 - accuracy: 0.9864
Epoch 5/10
67692/67692 [==============================] - 272s 4ms/sample - loss: 0.0257 - accuracy: 0.9932
Epoch 6/10
67692/67692 [==============================] - 257s 4ms/sample - loss: 0.0160 - accuracy: 0.9958
Epoch 7/10
67692/67692 [==============================] - 301s 4ms/sample - loss: 0.0164 - accuracy: 0.9956
Epoch 8/10
67692/67692 [==============================] - 289s 4ms/sample - loss: 0.0094 - accuracy: 0.9976
Epoch 9/10
67692/67692 [==============================] - 252s 4ms/sample - loss: 0.0113 - accuracy: 0.9967
Epoch 10/10
67692/67692 [==============================] - 245s 4ms/sample - loss: 0.0062 - accuracy: 0.9982
  • 6.在测试集上评估模型

22688/22688 [==============================] - 30s 1ms/step


Loss: 0.25497715034207047
Accuracy: 0.9393512010574341
  • 可以看到该模型精确率达到93.94%,效果相对来说还算是不错的。
  • 7.保存模型

  • 将训练完成的模型进行保存,得到一个.h5文件,这个文件也是后续TensorSpace项目所需要用到的核心文件。
model.save("../model/model_demo.h5")

你可能感兴趣的:(深度学习,深度学习,计算机视觉,图像识别,可视化)