A while ago my advisor asked me to build a neural network visualization project: train a model on a fruit dataset and then visualize it with TensorSpace. After tinkering with it for quite a while, I'm looking back to summarize the whole project in this post and record the problems I ran into. Feel free to leave questions in the comments.
The dataset contains 131 fruit classes, split very fine-grained by variety (apples alone, for example, come in seven varieties): Apples (different varieties: Crimson Snow, Golden, Golden-Red, Granny Smith, Pink Lady, Red, Red Delicious), Apricot, Avocado, Avocado ripe, Banana (Yellow, Red, Lady Finger), Beetroot Red, Blueberry, Cactus fruit, Cantaloupe (2 varieties), Carambula, Cauliflower, Cherry (different varieties, Rainier), Cherry Wax (Yellow, Red, Black), Chestnut, Clementine, Cocos, Corn (with husk), Cucumber (ripened), Dates, Eggplant, Fig, Ginger Root, Granadilla, Grape (Blue, Pink, White (different varieties)), Grapefruit (Pink, White), Guava, Hazelnut, Huckleberry, Kiwi, Kaki, Kohlrabi, Kumquats, Lemon (normal, Meyer), Lime, Lychee, Mandarine, Mango (Green, Red), Mangostan, Maracuja, Melon Piel de Sapo, Mulberry, Nectarine (Regular, Flat), Nut (Forest, Pecan), Onion (Red, White), Orange, Papaya, Passion fruit, Peach (different varieties), Pepino, Pear (different varieties, Abate, Forelle, Kaiser, Monster, Red, Stone, Williams), Pepper (Red, Green, Orange, Yellow), Physalis (normal, with Husk), Pineapple (normal, Mini), Pitahaya Red, Plum (different varieties), Pomegranate, Pomelo Sweetie, Potato (Red, Sweet, White), Quince, Rambutan, Raspberry, Redcurrant, Salak, Strawberry (normal, Wedge), Tamarillo, Tangelo, Tomato (different varieties, Maroon, Cherry Red, Yellow, not ripened, Heart), Walnut, Watermelon.
The experiment uses the following packages:
# Import the required packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import cv2
import tensorflow as tf
import keras
import os
print(os.listdir("../dataset/Training"))
['Apple Braeburn', 'Apple Crimson Snow', 'Apple Golden 1', 'Apple Golden 2', 'Apple Golden 3', 'Apple Granny Smith', 'Apple Pink Lady', 'Apple Red 1', 'Apple Red 2', 'Apple Red 3', 'Apple Red Delicious', 'Apple Red Yellow 1', 'Apple Red Yellow 2', 'Apricot', 'Avocado', 'Avocado ripe', 'Banana', 'Banana Lady Finger', 'Banana Red', 'Beetroot', 'Blueberry', 'Cactus fruit', 'Cantaloupe 1', 'Cantaloupe 2', 'Carambula', 'Cauliflower', 'Cherry 1', 'Cherry 2', 'Cherry Rainier', 'Cherry Wax Black', 'Cherry Wax Red', 'Cherry Wax Yellow', 'Chestnut', 'Clementine', 'Cocos', 'Corn', 'Corn Husk', 'Cucumber Ripe', 'Cucumber Ripe 2', 'Dates', 'Eggplant', 'Fig', 'Ginger Root', 'Granadilla', 'Grape Blue', 'Grape Pink', 'Grape White', 'Grape White 2', 'Grape White 3', 'Grape White 4', 'Grapefruit Pink', 'Grapefruit White', 'Guava', 'Hazelnut', 'Huckleberry', 'Kaki', 'Kiwi', 'Kohlrabi', 'Kumquats', 'Lemon', 'Lemon Meyer', 'Limes', 'Lychee', 'Mandarine', 'Mango', 'Mango Red', 'Mangostan', 'Maracuja', 'Melon Piel de Sapo', 'Mulberry', 'Nectarine', 'Nectarine Flat', 'Nut Forest', 'Nut Pecan', 'Onion Red', 'Onion Red Peeled', 'Onion White', 'Orange', 'Papaya', 'Passion Fruit', 'Peach', 'Peach 2', 'Peach Flat', 'Pear', 'Pear 2', 'Pear Abate', 'Pear Forelle', 'Pear Kaiser', 'Pear Monster', 'Pear Red', 'Pear Stone', 'Pear Williams', 'Pepino', 'Pepper Green', 'Pepper Orange', 'Pepper Red', 'Pepper Yellow', 'Physalis', 'Physalis with Husk', 'Pineapple', 'Pineapple Mini', 'Pitahaya Red', 'Plum', 'Plum 2', 'Plum 3', 'Pomegranate', 'Pomelo Sweetie', 'Potato Red', 'Potato Red Washed', 'Potato Sweet', 'Potato White', 'Quince', 'Rambutan', 'Raspberry', 'Redcurrant', 'Salak', 'Strawberry', 'Strawberry Wedge', 'Tamarillo', 'Tangelo', 'Tomato 1', 'Tomato 2', 'Tomato 3', 'Tomato 4', 'Tomato Cherry Red', 'Tomato Heart', 'Tomato Maroon', 'Tomato not Ripened', 'Tomato Yellow', 'Walnut', 'Watermelon']
# Training set preprocessing
training_fruit_img = []
training_label = []
for dir_path in glob.glob("../dataset/Training/*"):
    img_label = dir_path.split("/")[-1]             # class name = directory name
    for img_path in glob.glob(os.path.join(dir_path, "*.jpg")):
        img = cv2.imread(img_path)                  # read image (BGR)
        img = cv2.resize(img, (64, 64))             # resize to 64x64
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # convert BGR -> RGB
        training_fruit_img.append(img)
        training_label.append(img_label)
training_fruit_img = np.array(training_fruit_img)
training_label = np.array(training_label)
print(len(np.unique(training_label)))
131
# Test set preprocessing
test_fruit_img = []
test_label = []
for dir_path in glob.glob("../dataset/Test/*"):
    img_label = dir_path.split("/")[-1]
    for img_path in glob.glob(os.path.join(dir_path, "*.jpg")):
        img = cv2.imread(img_path)
        img = cv2.resize(img, (64, 64))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        test_fruit_img.append(img)
        test_label.append(img_label)
test_fruit_img = np.array(test_fruit_img)
test_label = np.array(test_label)
print(len(np.unique(test_label)))
131
# Test set (multiple-fruit images) preprocessing
test_fruits_img = []
tests_label = []
for img_path in glob.glob(os.path.join("../dataset/test-multiple_fruits", "*.jpg")):
    img_label = img_path.split("/")[-1]             # label = file name
    img = cv2.imread(img_path)
    img = cv2.resize(img, (64, 64))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    test_fruits_img.append(img)
    tests_label.append(img_label)
test_fruits_img = np.array(test_fruits_img)
tests_label = np.array(tests_label)
print(len(np.unique(tests_label)))
103
# Build label <-> id mappings and convert the training labels to integer ids
training_label_to_id = {v: k for k, v in enumerate(np.unique(training_label))}
training_id_to_label = {v: k for k, v in training_label_to_id.items()}
training_label_id = np.array([training_label_to_id[i] for i in training_label])
print(training_label_id)
array([ 0, 0, 0, ..., 130, 130, 130])
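The model is a small CNN. Below is a minimal sketch of a Keras Sequential definition that reproduces the summary that follows; the 3x3 kernels, 'same' padding and relu/softmax activations are assumptions inferred from the layer shapes and parameter counts.

# Sketch of a CNN matching the summary below; kernel size, padding and activations are assumed
model = keras.Sequential([
    keras.layers.Conv2D(16, (3, 3), padding="same", activation="relu",
                        input_shape=(64, 64, 3)),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu"),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(32, (3, 3), padding="same", activation="relu"),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), padding="same", activation="relu"),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(131, activation="softmax"),
])
model.summary()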
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 64, 64, 16) 448
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 32, 32, 32) 4640
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 16, 16, 32) 9248
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 8, 8, 32) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 8, 8, 64) 18496
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 1024) 0
_________________________________________________________________
dense_1 (Dense) (None, 256) 262400
_________________________________________________________________
dense_2 (Dense) (None, 131) 33667
=================================================================
Total params: 328,899
Trainable params: 328,899
Non-trainable params: 0
_________________________________________________________________
The loss function is sparse_categorical_crossentropy, the optimizer is Adamax, batch_size is set to 128, epochs is set to 10, and TensorBoard is used to visualize the training process.
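A minimal sketch of the corresponding compile/fit calls; the TensorBoard log_dir is a hypothetical placeholder, and no validation data is assumed.

# Compile and train; optimizer/loss/batch_size/epochs follow the text above, log_dir is assumed
tensorboard_cb = keras.callbacks.TensorBoard(log_dir="../logs")
model.compile(optimizer=keras.optimizers.Adamax(),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(training_fruit_img, training_label_id,
          batch_size=128,
          epochs=10,
          callbacks=[tensorboard_cb])

Training on the 67,692 training samples produces the following log:
Train on 67692 samples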
Epoch 1/10
67692/67692 [==============================] - 145s 2ms/sample - loss: 1.6539 - accuracy: 0.5947
Epoch 2/10
67692/67692 [==============================] - 156s 2ms/sample - loss: 0.2184 - accuracy: 0.9370
Epoch 3/10
67692/67692 [==============================] - 179s 3ms/sample - loss: 0.0812 - accuracy: 0.9764
Epoch 4/10
67692/67692 [==============================] - 264s 4ms/sample - loss: 0.0466 - accuracy: 0.9864
Epoch 5/10
67692/67692 [==============================] - 272s 4ms/sample - loss: 0.0257 - accuracy: 0.9932
Epoch 6/10
67692/67692 [==============================] - 257s 4ms/sample - loss: 0.0160 - accuracy: 0.9958
Epoch 7/10
67692/67692 [==============================] - 301s 4ms/sample - loss: 0.0164 - accuracy: 0.9956
Epoch 8/10
67692/67692 [==============================] - 289s 4ms/sample - loss: 0.0094 - accuracy: 0.9976
Epoch 9/10
67692/67692 [==============================] - 252s 4ms/sample - loss: 0.0113 - accuracy: 0.9967
Epoch 10/10
67692/67692 [==============================] - 245s 4ms/sample - loss: 0.0062 - accuracy: 0.9982
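Next, the model is evaluated on the test set. A minimal sketch, assuming the test labels are converted to integer ids with the same training_label_to_id dictionary built above:

# Map test labels to the same integer ids, then evaluate on the test set (sketch)
test_label_id = np.array([training_label_to_id[i] for i in test_label])
loss, accuracy = model.evaluate(test_fruit_img, test_label_id)
print("Loss:", loss)
print("Accuracy:", accuracy)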
22688/22688 [==============================] - 30s 1ms/step
Loss: 0.25497715034207047
Accuracy: 0.9393512010574341
# Save the trained model for the later TensorSpace visualization step
model.save("../model/model_demo.h5")